# Standard library imports
from argparse import ArgumentParser
import os, sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory one level above this script; appended to sys.path, presumably so
# the local packages imported below (lag_caVAE, hyperspherical_vae, utils)
# can be found when the script is run directly.
PARENT_DIR = os.path.dirname(THIS_DIR)
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np 
from tqdm import tqdm
from torch.distributions import Normal
from sklearn.metrics import r2_score

# local application imports
from lag_caVAE.lag import Lag_Net
from lag_caVAE.leap import Leap_Net
from lag_caVAE.leap import TF_Block_EXP_Residual_TV_2
from lag_caVAE.leap import TransformerEncoderLayerCategoricalsCatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CategoricalsCatPos
from lag_caVAE.leap import EstimatorNetwork
from lag_caVAE.leap import TransformerEncoderLayerCategoricals
from lag_caVAE.leap import TransformerEncoderLayer_v2_Categoricals
from lag_caVAE.leap import EstimatorNetwork_v2
from lag_caVAE.leap import EstimatorNetwork_NoScale
from lag_caVAE.leap import MLP_Mag
from lag_caVAE.leap import P_Neural_TIME_MultipleParameter
from lag_caVAE.leap import TransformerEncoderLayer
from lag_caVAE.leap import TransformerEncoderLayer_v2
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD, Encoder
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset

# Set the prediction length
T_pred = 100
# Set the first n index to compute the loss values
Loss_first_index = 100
# Set the first n FFT index (from lowest to highest)
cutoff_freq_input = 15
# Set the dataset that we want to use
dataset_type = 14 # 0->mu+L; 1->L; 2->mu; 3->mu (large); 4->mu (small, random); 5->mu (smaller, random)
# Learning rate
lr = 1e-3
# Define the size of the batch samples
num_batch = 32 # (gap = 20 -> 50 batches)
# Gradient clip
gradient_clip = 1.0
# Define the non-linearity for the q net and the recons net
MLP_Encoder_nonLinear = 'tanh' # tanh; elu; softplus; relu
# Define the parameter-estimator / ODE net used by Model.__init__:
# 0 -> FFT features + EstimatorNetwork; 1 -> FFT features + EstimatorNetwork_v2;
# 2 -> time-series features + MLP_Mag (P_Neural_TIME_MultipleParameter)
MODEL_TYPE = 0
# Set the plotting enable
Plot_enable = 0
# Set the attention
enable_attn = 1
# Set model_variant
# 0 -> separate attention encoders for pos and vel; 1 -> pos and vel are both attention;
# 2 -> pos is MLP and vel is attention; 3 -> pos is MLP and vel is categorical attention;
# 4 -> pos and vel combined together in a single encoder
model_variant = 4
# Set model_variant for attention
# 0 -> attention; 1 -> attention v2 with extra layer norm and skip connection;
# 2 -> based on 1, additional categoricals;
# 3 -> TransformerEncoderLayer_v2_CategoricalsCatPos; 4 -> TransformerEncoderLayer_v2_CatPos
# (3 and 4 are only accepted with model_variant 4)
model_variant_attn = 4
# Set if using physics or not
enable_physics = 1
# Set the weight of the reconstruction loss function
weight_recons = 1
# Set the loss function that aligns the states between the ODE solver and the encoder
Time_loss_weight = 1
# Set the FFT (cosine theta) loss weight
FFT_loss_weight = 0
# Set frame velocity loss weight
velocity_loss_enable = 0.0
# Define the gap_interval
gap_interval = 100
# Define source mask length
att_len = 7 #value->size of mask: 1->3; 2->5; 7->15; 12->25; 50 -> 101
# Attention model parameters
d_model_attn = 300 # the final size after doing CNN and flattening
nhead_attn = 10
d_middle_attn = 100 # no use here
d_final_attn = 6 # no use here
dropout_attn = 0.0
pos_en_scale_attn = 1.0
attn_nonlinearity = 'relu'
attn_nonlinearity_1 = 'tanh'
cnn_pooling = 'max' #or 'avg'
# for estimator network
# Define nonlinearity for the estimator network
nonlinearity = 0 # 0 for tanh; 1 for softplus (does not work well); 2 for elu
# Define decoder network
dec_size = 100

# default pos_en_scale is 1.0; for d_final, using Gaussian: d_final=5
NN_parameter = {'d_model': d_model_attn,
                'nhead': nhead_attn,
                'd_middle': d_middle_attn,
                'd_final': d_final_attn,
                'dropout': dropout_attn,
                'pos_en_scale': pos_en_scale_attn,
                'nonlinearity': attn_nonlinearity,
                'nonlinearity_1': attn_nonlinearity_1,
                'pooling': cnn_pooling}
gpu_number = 0

# Set the small value to prevent NaN
small_value = 1e-5
# Set the simulation Hz
Hz = 20
# Define NN input layer
NN_input_layer = [50,50]
# Define the time that the training loss will alternate
training_loss_interval = 15
mean_over_everything = 0

# Set the name of the file: every hyper-parameter above is encoded in it.
# The physics / no-physics runs share the same name and differ only by the
# '_NoPhysics' suffix, so the name is built once and the suffix appended.
save_dir = ('_Da'       + str(dataset_type)
            + '_STL'    + str(Time_loss_weight)
            + '_SFL'    + str(FFT_loss_weight)
            + '_VFL'    + str(velocity_loss_enable)
            + '_S'      + str(mean_over_everything)
            + '_Model'  + str(model_variant)
            + '-'       + str(model_variant_attn)
            + '-'       + str(MODEL_TYPE)
            + '_Attn'   + str(enable_attn)
            + '_'       + str(d_model_attn)
            + '-'       + str(nhead_attn)
            + '-'       + str(d_middle_attn)
            + '-'       + str(d_final_attn)
            + '-'       + str(dropout_attn)
            + '-'       + str(pos_en_scale_attn)
            + '-'       + str(attn_nonlinearity)
            + '-'       + str(attn_nonlinearity_1)
            + '-'       + str(cnn_pooling)
            + '-'       + str(att_len)
            + '_Est'    + str(nonlinearity)
            + '-'       + str(NN_input_layer[0])
            + '-'       + str(NN_input_layer[1])
            + '_Dec'    + str(dec_size)
            + '-'       + str(MLP_Encoder_nonLinear)
            + '_Lrate'  + str(lr)
            + '_batch'  + str(num_batch)
            + '_gradC'  + str(gradient_clip)
            + '_gapInt' + str(gap_interval))
if not enable_physics:
    save_dir += '_NoPhysics'


# Seed the global random number generators via pytorch_lightning's
# seed_everything so runs are repeatable. This executes at import time,
# i.e. before any Model (and therefore any network weight) is created.
seed_everything(42)

class Model(pl.LightningModule):

    def __init__(self, hparams, data_path=None):
        """Build the recognition encoders, observation decoder and ODE net.

        Which networks are built is controlled by the module-level flags
        ``enable_attn``, ``model_variant``, ``model_variant_attn`` and
        ``MODEL_TYPE``; layer sizes come from ``NN_parameter`` and the other
        module-level constants.

        Args:
            hparams: hyper-parameter namespace. ``T_pred`` is read here;
                ``train_dataloader`` additionally reads ``homo_u`` and
                ``batch_size``.
            data_path: dataset path, stored for ``train_dataloader``.

        Raises:
            ValueError: if the flag combination does not correspond to any
                network this constructor can build (such combinations used
                to leave ``recog_q_net*`` / ``ode`` undefined, which only
                surfaced later as an AttributeError).
        """
        super(Model, self).__init__()

        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        self.loss_fn_mean = torch.nn.MSELoss()
        self.size_image = 64*64
        self.input_dim = 4  # (cos, sin, theta_dot, u)
        self.plu_output = 0.55
        # For plotting purpose
        self.count = 0
        self.cutoff_index_input = cutoff_freq_input

        # Sub-networks are created in the same order as before (encoders,
        # obs_net, MLP_Spec_mu, MLP_Spec_L, ode) so that, under the global
        # seed, the random initialisation and the parameter order seen by
        # the optimizer / stored in checkpoints do not change.
        self._build_recognition_nets()
        self.obs_net = MLP_Encoder(1, dec_size, self.size_image, nonlinearity=MLP_Encoder_nonLinear)
        self._build_ode_net()

        self.train_dataset = None
        self.non_ctrl_ind = 1

        # Generate the src masks. self.T_pred (from hparams) is used instead
        # of the module-level T_pred so the mask sizes always agree with the
        # attention encoders, which are built with max_len = self.T_pred + 1.
        self.SRC_MAS_V = [self.src_mask(self.T_pred + 1 - i)
                          for i in np.arange(0, self.T_pred - self.cutoff_index_input, gap_interval)]

        self.training_loss_flag = 0

    def _attn_encoder(self, layer_cls, d_final, pooling=True):
        """Instantiate attention encoder ``layer_cls`` with the shared NN_parameter settings.

        ``pooling=False`` omits the ``pooling`` keyword, matching the encoders
        that were originally constructed without it.
        """
        kwargs = dict(d_model=NN_parameter['d_model'],
                      nhead=NN_parameter['nhead'],
                      d_middle=NN_parameter['d_middle'],
                      d_final=d_final,
                      dropout=NN_parameter['dropout'],
                      max_len=self.T_pred+1,
                      pos_en_scale=NN_parameter['pos_en_scale'],
                      activation=NN_parameter['nonlinearity'],
                      activation_1=NN_parameter['nonlinearity_1'])
        if pooling:
            kwargs['pooling'] = NN_parameter['pooling']
        return layer_cls(**kwargs)

    def _build_recognition_nets(self):
        """Create recog_q_net / recog_q_net_velocity / recog_q_net_state per the module flags."""
        if not enable_attn:
            self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)
            return

        # model_variant_attn -> encoder class, as accepted by model_variants 1 and 2.
        attn_classes = {0: TransformerEncoderLayer,
                        1: TransformerEncoderLayer_v2,
                        2: TransformerEncoderLayer_v2_Categoricals}

        def pick(classes):
            # Resolve model_variant_attn, failing loudly instead of building nothing.
            if model_variant_attn not in classes:
                raise ValueError('model_variant_attn=%r is not supported for model_variant=%r'
                                 % (model_variant_attn, model_variant))
            return classes[model_variant_attn]

        if model_variant == 0:
            # Separate attention encoders for position and velocity (no pooling option).
            self.recog_q_net = self._attn_encoder(TransformerEncoderLayer, 3, pooling=False)
            self.recog_q_net_velocity = self._attn_encoder(TransformerEncoderLayer, 2, pooling=False)
        elif model_variant == 1:
            # Attention for the locations and attention for the velocity.
            layer_cls = pick(attn_classes)
            self.recog_q_net = self._attn_encoder(layer_cls, 3)
            self.recog_q_net_velocity = self._attn_encoder(layer_cls, 3)
        elif model_variant == 2:
            # MLP for the locations, attention for the velocity.
            layer_cls = pick(attn_classes)
            self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)
            self.recog_q_net_velocity = self._attn_encoder(layer_cls, 3)
        elif model_variant == 3:
            # MLP for the locations, categorical attention for the velocity.
            self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)
            self.recog_q_net_velocity = self._attn_encoder(TransformerEncoderLayerCategoricalsCatPos, 2,
                                                           pooling=False)
        elif model_variant == 4:
            # A single NN that processes all the information (3 + 3 outputs).
            state_classes = dict(attn_classes)
            state_classes[3] = TransformerEncoderLayer_v2_CategoricalsCatPos
            state_classes[4] = TransformerEncoderLayer_v2_CatPos
            self.recog_q_net_state = self._attn_encoder(pick(state_classes), 3+3)
        else:
            raise ValueError('unsupported model_variant=%r' % (model_variant,))

    def _build_ode_net(self):
        """Create MLP_Spec_mu, MLP_Spec_L and the ODE network selected by MODEL_TYPE."""
        def estimator(est_cls, input_size):
            return est_cls(input_size=input_size, input_layer=NN_input_layer,
                           nonlinearity=nonlinearity, mag=0.55)

        if MODEL_TYPE in (0, 1):
            # Estimators fed with FFT features; each gets its own input_size list.
            est_cls = EstimatorNetwork if MODEL_TYPE == 0 else EstimatorNetwork_v2
            self.MLP_Spec_mu = estimator(est_cls, [self.cutoff_index_input, self.cutoff_index_input])
            self.MLP_Spec_L = estimator(est_cls, [self.cutoff_index_input, self.cutoff_index_input])
            ode_cls = Leap_Net
        elif MODEL_TYPE == 2:
            # Estimators fed with the (cos, sin, theta_dot) time series of length self.T_pred.
            self.MLP_Spec_mu = estimator(MLP_Mag, 3*self.T_pred)
            self.MLP_Spec_L = estimator(MLP_Mag, 3*self.T_pred)
            ode_cls = P_Neural_TIME_MultipleParameter
        else:
            raise ValueError('unsupported MODEL_TYPE=%r' % (MODEL_TYPE,))

        self.ode = ode_cls(MLP_Spec_mu=self.MLP_Spec_mu,
                           MLP_Spec_L=self.MLP_Spec_L,
                           mul_output=0.55,
                           plu_output=self.plu_output,
                           cutoff_index=self.cutoff_index_input,
                           device=self.device,
                           input_dim=self.input_dim)

    def train_dataloader(self):
        """Return the shuffled training DataLoader and refresh ``self.t_eval``.

        Without ``hparams.homo_u`` a fresh ImageDataset is built on every call.
        With it, a single HomoImageDataset is cached on ``self.train_dataset``
        and only its ``u_idx`` changes per epoch: for epochs below 1000, even
        epochs use index 0 and odd epochs cycle through 1..8 (tracked in
        ``self.non_ctrl_ind``); from epoch 1000 on the index is
        ``current_epoch % 9``.
        """
        if not self.hparams.homo_u:
            # This is our default setting since all the u is zero
            dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(dataset.t_eval)
            return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=my_collate)

        # must set trainer flag reload_dataloaders_every_epoch=True
        if self.train_dataset is None:
            self.train_dataset = HomoImageDataset(self.data_path, self.hparams.T_pred)

        epoch = self.current_epoch
        if epoch >= 1000:
            u_idx = epoch % 9
        elif epoch % 2 == 0:
            # feed zero ctrl dataset and ctrl dataset in turns
            u_idx = 0
        else:
            u_idx = self.non_ctrl_ind
            following = self.non_ctrl_ind + 1
            self.non_ctrl_ind = 1 if following == 9 else following

        self.train_dataset.u_idx = u_idx
        self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from two (cos, sin) frames.

        Columns 0 and 1 of ``q0_m_n`` / ``q1_m_n`` are read as cos(theta) and
        sin(theta). The result is ``(cos0 * dsin - sin0 * dcos) / delta_t``
        evaluated column-wise, which approximates theta_dot because
        d(sin) = cos dtheta and d(cos) = -sin dtheta. The ``0:1`` slices keep
        the column dimension, so the output has shape (N, 1).
        """
        cos_prev, sin_prev = q0_m_n[:, 0:1], q0_m_n[:, 1:2]
        d_cos = q1_m_n[:, 0:1] - cos_prev
        d_sin = q1_m_n[:, 1:2] - sin_prev
        return -d_cos * sin_prev / delta_t + d_sin * cos_prev / delta_t

    def angle_vel_est_euler(self, q0, delta_t):
        """Forward-difference derivative of ``q0`` along its first dimension.

        Returns ``(q0[t+1] - q0[t]) / delta_t`` for every consecutive pair,
        with a new axis inserted at position 1 (so a 1-D input of length T
        gives shape (T-1, 1)).
        """
        forward_diff = q0[1:] - q0[:-1]
        return (forward_diff / delta_t).unsqueeze(1)

    def get_system_parameter(self, x):
        """Estimate (mu, L) from the FFT magnitudes of the observed state.

        ``x[2:]`` is read as interleaved steps of ``self.input_dim`` values in
        the order ``[cos, sin, theta_dot, u]``; the first two entries of ``x``
        are skipped. For cos, sin and theta_dot (the latter divided by 7.0 as
        a normalisation), the magnitudes of the lowest
        ``self.cutoff_index_input`` FFT bins are computed; their logs are fed
        to ``MLP_Spec_mu`` and ``MLP_Spec_L`` and the outputs shifted by
        ``self.plu_output``.

        Returns:
            (mu, L, cos_theta_mag, theta_dot_mag)
        """
        # Real signal as (N, 2) [real, imag] pairs with a zero imaginary part,
        # the layout the FFT helper expects. Built once and sliced per channel
        # (it used to be rebuilt identically for every channel).
        state = torch.unsqueeze(x[2:], 1).to(self.device)
        state = torch.cat((state, torch.zeros((state.shape[0], 1)).to(self.device)), dim=1)

        cos_theta_mag = self._fft_magnitude(state[0::self.input_dim, :])
        sin_theta_mag = self._fft_magnitude(state[1::self.input_dim, :])
        theta_dot_mag = self._fft_magnitude(state[2::self.input_dim, :] / 7.0)  # do normalization

        # Only Freq
        log_cos = torch.log(cos_theta_mag)
        log_sin = torch.log(sin_theta_mag)
        log_theta_dot = torch.log(theta_dot_mag)
        mu = self.MLP_Spec_mu(log_cos, log_sin, log_theta_dot) + self.plu_output
        L = self.MLP_Spec_L(log_cos, log_sin, log_theta_dot) + self.plu_output

        return mu, L, cos_theta_mag, theta_dot_mag

    def _fft_magnitude(self, signal):
        """Return |FFT(signal)| for the first ``self.cutoff_index_input`` bins.

        ``signal`` is an (N, 2) tensor of [real, imag] pairs. On torch < 1.8
        the legacy callable ``torch.fft(input, signal_ndim, normalized)`` is
        used; on newer torch ``torch.fft`` is a module (calling it raises
        TypeError), so ``torch.fft.fft`` is used instead. Both compute an
        unnormalised forward transform.
        """
        if callable(torch.fft):
            spectrum = torch.fft(signal, 1, normalized=False)
        else:
            spectrum = torch.view_as_real(torch.fft.fft(torch.view_as_complex(signal.contiguous())))
        power = spectrum[:, 0] ** 2 + spectrum[:, 1] ** 2
        return power[0:self.cutoff_index_input] ** 0.5

    def get_system_parameter_time(self, x):
        """Predict system parameters directly from the time-domain trajectory.

        ``x`` is a flat 1-D tensor. The first two entries are skipped, and the
        rest is interleaved with stride ``self.input_dim``. Channels 0, 1 and
        2 of every step (cos, sin, theta_dot) are regrouped channel by channel
        and concatenated. Any remaining channel, such as the control input u,
        is discarded. The result is fed to ``MLP_Spec_mu`` and ``MLP_Spec_L``.

        Returns:
            (mu, L, None, None). The trailing ``None`` values keep the same
            4-tuple shape as ``get_system_parameter``.
        """
        trajectory = x[2:]

        # [cos..., sin..., theta_dot...]; the u channel is dropped.
        x_input = torch.cat(
            [trajectory[channel::self.input_dim] for channel in range(3)], dim=0)

        mu = self.MLP_Spec_mu(x_input) + self.plu_output
        L = self.MLP_Spec_L(x_input) + self.plu_output

        return mu, L, None, None

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Build a batch of 2x3 affine matrices for ``F.affine_grid``.

        Each matrix is
            [[cos, -sin, -x*cos + y*sin],
             [sin,  cos, -x*sin - y*cos]]
        with shape ``(bs, 2, 3)``. ``bs`` defaults to ``self.bs``. The dtype
        and device come from ``self``.
        """
        batch = bs if bs is not None else self.bs
        theta = torch.zeros([batch, 2, 3], dtype=self.dtype, device=self.device)
        # (row, col, value) triples, added into the zero matrix in place.
        entries = (
            (0, 0, cos), (0, 1, -sin), (0, 2, -x * cos + y * sin),
            (1, 0, sin), (1, 1, cos), (1, 2, -x * sin - y * cos),
        )
        for row, col, value in entries:
            theta[:, row, col] += value
        return theta

    def encode(self, batch_image):
        """Encode flattened frames with ``recog_q_net`` (non-attention path).

        Returns:
            q_m: raw 2-D mean per row (first two output columns).
            q_v: ``softplus(third column) + 1``. ``forward`` passes it as the
                second argument of ``VonMisesFisher``.
            q_m_n: ``q_m`` scaled to unit norm. The module-level
                ``small_value`` guards against division by zero.
        """
        # A disabled NaN check over recog_q_net_velocity / obs_net parameters
        # (which halted via sys.exit) used to live here as a string literal.
        raw = self.recog_q_net(batch_image + 1e-5)
        mean, log_var = raw.split([2, 1], dim=1)
        unit_mean = mean / (mean.norm(dim=-1, keepdim=True) + small_value)
        concentration = F.softplus(log_var) + 1
        return mean, concentration, unit_mean


    def src_mask(self, dim):
        """Banded additive attention mask of shape ``(dim, dim)``, float32.

        Entry ``(i, j)`` is ``0.0`` when ``|i - j| <= att_len`` and ``-inf``
        otherwise, so each position attends only to a window of ``att_len``
        steps on either side. ``att_len`` is a module-level setting.
        Reference: https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390/2
        """
        positions = torch.arange(dim)
        offsets = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs()
        mask = torch.full((dim, dim), float('-inf'), dtype=torch.float32)
        mask[offsets <= att_len] = 0.0
        return mask

    def print_parameter(self):
        """Debug helper: print a header plus the first named parameter of each sub-network."""
        # (header, attribute name) pairs; attributes are looked up lazily in order.
        sections = (
            ('-----h: recog_q_net-----', 'recog_q_net'),
            ('-----h: recog_q_net_velocity-----', 'recog_q_net_velocity'),
            ('-----f: MLP_Spec_mu-----', 'MLP_Spec_mu'),
            ('-----g: obs_net-----', 'obs_net'),
        )
        for header, attr_name in sections:
            print(header)
            for name, param in getattr(self, attr_name).named_parameters():
                print(name, param)
                break

    def encode_self_attention(self, batch_image, src_mask_v):
        """Encode one trajectory of frames with the attention encoder(s).

        The module-level ``model_variant`` flag selects the sub-networks that
        are called and the returned tuple:

        * ``0``: ``(q_m_loc, q_v_loc, q_m_loc_n, q_m_vel, q_v_vel,
          attn_output_weight, attn_output_weight_velocity)``.
          - Both networks are called with the mask.
          - The velocity mean is ``tanh(.) * 7``.
          - The velocity scale is ``sigmoid(.) * 0.05 + 0.0001``.
        * ``1``, ``2``, ``4``: ``(q_m_loc, q_v_loc, q_m_loc_n, q_m_vel,
          q_v_vel, q_m_vel_n, attn_output_weight, attn_output_weight_velocity)``.
          - Velocity uses the same parameterization as location: a 2-D mean,
            a unit-normalized copy, and a ``softplus + 1`` concentration.
          - Variant 1 calls ``recog_q_net`` without a mask.
          - Variant 2 returns ``[]`` for ``attn_output_weight``.
          - Variant 4 uses the single network ``recog_q_net_state`` and
            returns its attention weights in both weight slots.
        * ``3``: ``(q_m_loc, q_v_loc, q_m_loc_n, attn_output_velocity,
          attn_output_weight, attn_output_weight_velocity)``.
          - ``attn_output_velocity`` is the raw squeezed velocity-network
            output.
          - ``attn_output_weight`` is ``[]``.

        Args:
            batch_image: frames of a single trajectory. ``forward`` passes
                ``X[:, b].reshape(T, d*d)``.
            src_mask_v: attention mask. It is moved to ``self.device`` and
                passed as ``src_mask`` to the masked network call(s).

        Raises:
            ValueError: if ``model_variant`` is not 0, 1, 2, 3 or 4. Previously
                the method silently returned ``None`` and the caller failed
                while unpacking.
        """
        def _circle_params(mean, raw_logv):
            # Unit-normalize a 2-D mean (small_value avoids division by zero)
            # and map the raw output to a concentration >= 1 via softplus.
            mean_n = mean / (mean.norm(dim=-1, keepdim=True) + small_value)
            return mean_n, F.softplus(raw_logv) + 1

        if model_variant not in (0, 1, 2, 3, 4):
            raise ValueError(
                'Unsupported model_variant {!r}; expected 0, 1, 2, 3 or 4.'.format(model_variant))

        # All variants add the same small pixel offset and insert a singleton
        # axis at dim 1 before calling the networks.
        X_attn = (batch_image + 1e-5).unsqueeze(1)

        if model_variant == 4:
            # One network predicts position and velocity together.
            attn_output, attn_output_weight = self.recog_q_net_state(X_attn, src_mask=src_mask_v.to(self.device))
            attn_output = attn_output.squeeze()
            q_m_loc, q_logv_loc, q_m_vel, q_logv_vel = attn_output.split([2, 1, 2, 1], dim=1)
            q_m_loc_n, q_v_loc = _circle_params(q_m_loc, q_logv_loc)
            q_m_vel_n, q_v_vel = _circle_params(q_m_vel, q_logv_vel)
            return q_m_loc, q_v_loc, q_m_loc_n, \
                   q_m_vel, q_v_vel, q_m_vel_n, \
                   attn_output_weight, attn_output_weight

        # Position network: masked in variant 0, unmasked in variant 1. In
        # variants 2 and 3 it returns only its output (no attention weights).
        if model_variant == 0:
            attn_output, attn_output_weight = self.recog_q_net(X_attn, src_mask=src_mask_v.to(self.device))
        elif model_variant == 1:
            attn_output, attn_output_weight = self.recog_q_net(X_attn)
        else:
            attn_output_weight = []
            attn_output = self.recog_q_net(X_attn)

        # The velocity network always attends through the mask.
        attn_output_velocity, attn_output_weight_velocity = self.recog_q_net_velocity(X_attn, src_mask=src_mask_v.to(self.device))
        attn_output = attn_output.squeeze()
        attn_output_velocity = attn_output_velocity.squeeze()

        q_m_loc, q_logv_loc = attn_output.split([2, 1], dim=1)
        q_m_loc_n, q_v_loc = _circle_params(q_m_loc, q_logv_loc)

        if model_variant == 0:
            # Scalar velocity: mean squashed into (-7, 7), small positive scale.
            q_m_vel, q_v_vel = attn_output_velocity.split([1, 1], dim=1)
            q_m_vel = torch.tanh(q_m_vel) * 7
            q_v_vel = torch.sigmoid(q_v_vel) * 0.05 + 0.0001
            return q_m_loc, q_v_loc, q_m_loc_n, q_m_vel, q_v_vel, attn_output_weight, attn_output_weight_velocity

        if model_variant == 3:
            # The velocity output is returned raw (no distribution parameters).
            return q_m_loc, q_v_loc, q_m_loc_n, \
                   attn_output_velocity, \
                   attn_output_weight, attn_output_weight_velocity

        # Variants 1 and 2: velocity parameterized like the location.
        q_m_vel, q_logv_vel = attn_output_velocity.split([2, 1], dim=1)
        q_m_vel_n, q_v_vel = _circle_params(q_m_vel, q_logv_vel)
        return q_m_loc, q_v_loc, q_m_loc_n, \
               q_m_vel, q_v_vel, q_m_vel_n, \
               attn_output_weight, attn_output_weight_velocity

    def forward(self, X, u, S, TIME_INDEX, src_mask_v, mu_full_length_list=None):

        """Encode, simulate and decode every trajectory in the batch; results go to ``self``.

        For each trajectory ``b``:
          1. The frames ``X[:, b]`` are encoded into (cos, sin) positions plus
             a velocity estimate.
          2. ``mu``/``L`` are predicted from the encoded states.
          3. The state vector ``[theta0, theta_dot0, encoded trajectory...]``
             is integrated with ``odeint(self.ode, ...)`` when
             ``enable_physics`` is set; otherwise the encoded states are
             reused directly.
          4. The resulting angles are rendered back into images through
             ``obs_net`` and an affine warp.

        Args:
            X: image batch, unpacked as ``(T, bs, d, d)``. This sets ``self.bs``.
            u: ignored. It is overwritten with zeros of shape ``(T, 1)`` inside
                the per-trajectory loop.
            S: not used by this body; it appears only in commented-out debug
                prints.
            TIME_INDEX: start offset into ``self.t_eval``.
                - When 0, the predicted ``mu``/``L`` are stored in
                  ``self.mu_list``/``self.L_list``, the ODE is reset with
                  ``prediction_enable=1``, and ``self.count`` is incremented.
                - Otherwise, the ODE is reset with
                  ``mu=mu_full_length_list[b]`` and ``L=1.0``.
            src_mask_v: attention mask, forwarded to ``encode_self_attention``
                when ``enable_attn`` is set.
            mu_full_length_list: per-trajectory ``mu`` values from an earlier
                ``TIME_INDEX == 0`` call. Required when ``TIME_INDEX != 0``.

        Returns:
            None. Results are written to attributes, among them:
            - ``self.qT`` / ``self.q_dotT``: simulated positions and velocities.
            - ``self.qT_enc`` / ``self.q_dotT_enc``: encoded positions and
              velocities.
            - ``self.qT_enc_frameV`` / ``self.q_dotT_enc_frameV``: encoded
              positions with the frame-difference velocity.
            - ``self.Xrec``: reconstructed images.
            - For ``MODEL_TYPE`` 0/1, the DFT magnitudes ``self.FF_p_Enc``,
              ``self.FF_v_Enc``, ``self.FF_p_ODE`` and ``self.FF_v_ODE``.
        """
        # Disabled snippet kept for reference (re-enabled requires_grad on all sub-networks):
        # for param in self.recog_q_net.parameters():
        #     param.requires_grad = True
        # for param in self.recog_q_net_velocity.parameters():
        #     param.requires_grad = True
        # for param in self.MLP_Spec_mu.parameters():
        #     param.requires_grad = True
        # for param in self.obs_net.parameters():
        #     param.requires_grad = True
        [T, self.bs, d, d] = X.shape
        #T = len(self.t_eval)

        x_enc_list = []
        x_enc_frameV_list = []
        x_sim_list = []
        mu_list = []
        L_list = []

        Enc_theta_FFT_list = []
        Enc_theta_dot_FFT_list = []
        ODE_theta_FFT_list = []
        ODE_theta_dot_FFT_list = []
        Attn_output_weight_list = []
        Attn_output_weight_velocity_list = []

        # Trajectories are processed one at a time; per-trajectory results are
        # collected in the lists above and stacked after the loop.
        for batch_ii in tqdm(range(self.bs)):
        #for batch_ii in range(self.bs):

            # NOTE(review): this shadows the ``u`` argument; the passed-in control is never used.
            u = torch.zeros((T,1)).to(self.device)

            # =======Encode=======
            # Get the mean and the variance of the distribution
            # The tuple layout returned by encode_self_attention depends on model_variant.
            if enable_attn:
                if model_variant==0:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                            self.q0_dot_m, self.q0_dot_v, \
                                self.attn_output_weight, self.attn_output_weight_velocity \
                                    = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)
                elif model_variant==3:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                            self.q0_dot_m, \
                                self.attn_output_weight, self.attn_output_weight_velocity \
                                    = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)
                else:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                        self.q0_dot_m, self.q0_dot_v, self.q0_dot_m_n, \
                            self.attn_output_weight, self.attn_output_weight_velocity \
                                = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)
                
                # Keep element [0] of the attention weights for later stacking.
                if model_variant == 1:
                    Attn_output_weight_list.append(self.attn_output_weight[0])
                Attn_output_weight_velocity_list.append(self.attn_output_weight_velocity[0])
            else:
                self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[:,batch_ii,:,:].reshape(T, d*d))

            # Sample mean and the variance
            # Reparameterized sample of the (cos, sin) position from the VMF posterior;
            # P_q (hyperspherical uniform) is the prior used for the KL term in training_step.
            self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v) 
            self.P_q = HypersphericalUniform(1, device=self.device)
            self.q0 = self.Q_q.rsample().to(self.device) # bs, 2 = cos\theta and sin\theta instead of \theta

            # NOTE(review): resampling until no NaN never terminates if the VMF parameters themselves are NaN.
            while torch.isnan(self.q0).any():
                self.q0 = self.Q_q.rsample().to(self.device) # a bad way to avoid nan

            if enable_attn:
                if model_variant==0:
                    # Sample mean and the variance
                    # Gaussian velocity posterior with a standard-normal prior (used for the KL term).
                    self.Q_dot_q = Normal(self.q0_dot_m, self.q0_dot_v)
                    self.P_normal = Normal(torch.zeros_like(self.q0_dot_m), torch.ones_like(self.q0_dot_v))
                    self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # bs, 2 = cos\theta and sin\theta instead of \theta
                    # Drop the last step so the velocity lines up with the T-1 positions used below.
                    self.q_dot0 = self.q_dot0[0:T-1]

                    # Compute the velocity using finit element
                    # This is achieved by comparing two frames
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                elif model_variant == 3:
                    # Variant 3: the raw velocity-network output is used directly (no sampling).
                    self.q_dot0 = self.q0_dot_m[0:T-1].unsqueeze(1) # bs, 2 = cos\theta and sin\theta instead of \theta
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                else:
                    # Using attention to estimate the velocity
                    # Sample mean and the variance
                    self.Q_dot_q = VonMisesFisher(self.q0_dot_m_n, self.q0_dot_v) 
                    self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # bs, 2 = cos\theta and sin\theta instead of \theta
                    # NOTE(review): same potential infinite loop as the position resampling above.
                    while torch.isnan(self.q_dot0).any():
                        self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # a bad way to avoid nan
                    # Trim it make it one time step smaller. And the output is size of [T,2], 2 is for cos/sin
                    # Only the first component of the 2-D sample is kept, giving shape (T-1, 1).
                    self.q_dot0 = self.q_dot0[0:T-1,0].unsqueeze(1)

                    # Make it in the resonable scale [-7,7]
                    self.q_dot0 = self.q_dot0 * 7
                    
                    # Compute the velocity using finit element
                    # This is achieved by comparing two frames
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
            else:
                # Estimate velocity using finit element
                # Without attention, both velocities are the same frame-difference estimate.
                self.q_dot0 = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)

            # Estimate euler velocity
            #self.q_dot0 = self.angle_vel_est_euler(torch.atan2(self.q0[:,1],self.q0[:,0]) + np.pi, \
            #                                       self.t_eval[1]-self.t_eval[0]).to(self.device)

            # predict
            # Per-step encoded state [cos, sin, theta_dot, u] for the first T-1 steps.
            z0_u = torch.cat((self.q0[0:T-1], self.q_dot0, u[0:T-1]), dim=1) #torch.Size([simulation_length, 4])
            x_enc_list.append(z0_u)
            x_enc_frameV_list.append( torch.cat((self.q0[0:T-1], self.q_dot0_compareFrame, u[0:T-1]), dim=1))

            #if batch_ii == 0:
            #    print('before ode (cos,sin,v,u):',z0_u[0:10,:],'index:',batch_ii,z0_u.shape) 
            # Transfer to the theta, theta_dot state
            # Flatten to a single row: [cos0, sin0, v0, u0, cos1, ...].
            z0_u = z0_u.reshape((1,-1))


            # This is form: atan2(y=sin, x=cos)
            # The first step's angle is shifted by +pi; the inverse shift (- np.pi) is applied after the ODE below.
            theta = torch.zeros((1,1)).to(self.device) + torch.atan2(z0_u[:,1],z0_u[:,0]) + np.pi
            theta_dot =  torch.zeros((1,1)).to(self.device) + z0_u[0,2]
            # Append with the inital state
            # Prefix [theta0, theta_dot0] in front of the flattened trajectory.
            s_init = torch.cat((theta, theta_dot), dim=1).to(self.device)
            z0_u = torch.cat((s_init.to(self.device),z0_u),dim=1)
            # Get the FFT and mu, L predictions
            # This is okay for the shorter simulation length, since I also want them to be correct
            if MODEL_TYPE == 0 or MODEL_TYPE == 1:
                mu, L, Enc_theta_FFT, Enc_theta_dot_FFT = self.get_system_parameter(z0_u[0,:])
            elif MODEL_TYPE == 2:
                mu, L, _, _ = self.get_system_parameter_time(z0_u[0,:])
            # Append the number on the list
            # We only need to store the mu and L in the first place
            # Stored as detached numpy values, so no gradient flows through mu_list / L_list.
            if TIME_INDEX == 0:
                mu_list.append(mu.detach().cpu().numpy()[0])
                L_list.append(L.detach().cpu().numpy()[0])
            # Do the simulation
            # The following ifelse ensures that the prediction is only made when given the full length of the data
            if TIME_INDEX == 0:
                self.ode.reset_prediction(prediction_enable=1)
            else:
                self.ode.reset_prediction(prediction_enable=0,mu=mu_full_length_list[batch_ii],L=1.0)

            self.t_eval_ = self.t_eval[TIME_INDEX:]
            if enable_physics:
                zT_u = odeint(self.ode, z0_u[0,:], self.t_eval_, method=self.hparams.solver) # T,299
            else:
                #Require theta, and theta_dot
                # No ODE: reuse the encoded angles/velocities and pad one zero step so there are T rows.
                z0_u = z0_u[:,2:]
                theta = torch.atan2(z0_u[:,1::4],z0_u[:,0::4]) + np.pi
                theta_dot = z0_u[:,2::4]
                zT_u = torch.cat((theta,theta_dot),axis=0)
                zT_u = torch.cat((zT_u,torch.zeros((2,1)).to(self.device)),axis=1)
                zT_u = zT_u.T

            # Get the state
            # Only the first two columns (theta, theta_dot) are used from here on.
            zT_u = zT_u[:,0:2]
            # get cosine
            # Undo the +pi shift applied before integration.
            z_cos = torch.cos(zT_u[:,0] - np.pi).unsqueeze(1)
            z_sin = torch.sin(zT_u[:,0] - np.pi).unsqueeze(1)
            z_vel = zT_u[:,1].unsqueeze(1).to(self.device)
            z_u   = torch.zeros(T).unsqueeze(1).to(self.device)
            zT_u  = torch.cat((z_cos, z_sin, z_vel, z_u),dim=1) # T,4

            # ODE output contains one more step
            state = zT_u[0:T-1].reshape((1,-1))

            # Get the Fourier transform of the ODE state
            if MODEL_TYPE == 0 or MODEL_TYPE == 1:
                _, _, ODE_theta_FFT, ODE_theta_dot_FFT = self.get_system_parameter(torch.cat((s_init.to(self.device),state),dim=1)[0])
            elif MODEL_TYPE == 2:
                # NOTE(review): the result is discarded; this call has no effect beyond the computation itself.
                _, _, _, _ = self.get_system_parameter_time(torch.cat((s_init.to(self.device),state),dim=1)[0])

            #if batch_ii == 0:
            #    print('after ode (cos,sin,v,u):',zT_u[0:10,:],'index:',batch_ii,zT_u.shape)
            #if batch_ii == 0:
            #    print('true state (cos,sin,theta,theta_dot):',S[0:10,batch_ii,0:4],'index:',batch_ii,zT_u.shape)

            x_sim_list.append(zT_u) # T, bs, 4

            if MODEL_TYPE == 0 or MODEL_TYPE == 1:
                Enc_theta_FFT_list.append(Enc_theta_FFT)
                Enc_theta_dot_FFT_list.append(Enc_theta_dot_FFT)
                ODE_theta_FFT_list.append(ODE_theta_FFT)
                ODE_theta_dot_FFT_list.append(ODE_theta_dot_FFT)

        # Stack the data, and retain_grad() after using stack function
        # Stacked as (bs, steps, 4), then permuted to (steps, bs, 4).
        x_sim_list = torch.stack(x_sim_list, axis=0)
        x_sim_list.retain_grad()
        x_sim_list = x_sim_list.permute(1,0,2)

        x_enc_list = torch.stack(x_enc_list, axis=0)
        x_enc_list.retain_grad()
        x_enc_list = x_enc_list.permute(1,0,2)

        x_enc_frameV_list = torch.stack(x_enc_frameV_list, axis=0)
        x_enc_frameV_list.retain_grad()
        x_enc_frameV_list = x_enc_frameV_list.permute(1,0,2)

        if MODEL_TYPE == 0 or MODEL_TYPE == 1:
            self.FF_p_Enc = torch.stack(Enc_theta_FFT_list, axis=0)
            self.FF_p_Enc.retain_grad()
            self.FF_v_Enc = torch.stack(Enc_theta_dot_FFT_list, axis=0)
            self.FF_v_Enc.retain_grad()
            self.FF_p_ODE = torch.stack(ODE_theta_FFT_list, axis=0)
            self.FF_p_ODE.retain_grad()
            self.FF_v_ODE = torch.stack(ODE_theta_dot_FFT_list, axis=0)
            self.FF_v_ODE.retain_grad()
            if model_variant == 1:
                self.Attn_output_weight_list = torch.stack(Attn_output_weight_list, axis=0)
            if enable_attn == 1:
                self.Attn_output_weight_velocity_list = torch.stack(Attn_output_weight_velocity_list, axis=0)

        # We only need to store the mu and L in the first place
        if TIME_INDEX == 0:
            self.mu_list = mu_list
            self.L_list = L_list
        
        # Split [cos, sin | theta_dot | u]; positions are flattened to (steps*bs, 2).
        self.qT, self.q_dotT, _ = x_sim_list.split([2, 1, 1], dim=-1)
        self.qT = self.qT.contiguous()
        self.qT = self.qT.view(T*self.bs, 2)

        self.qT_enc, self.q_dotT_enc, _ = x_enc_list.split([2, 1, 1], dim=-1)
        self.qT_enc = self.qT_enc.contiguous()
        self.qT_enc = self.qT_enc.view((T-1)*self.bs, 2)

        self.qT_enc_frameV, self.q_dotT_enc_frameV, _ = x_enc_frameV_list.split([2, 1, 1], dim=-1)
        self.qT_enc_frameV = self.qT_enc_frameV.contiguous()
        self.qT_enc_frameV = self.qT_enc_frameV.view((T-1)*self.bs, 2)

        # =======Decode=======
        # Here we want to get the content of the pole
        # obs_net maps a constant input of ones to the pole content image.
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)
        # Get the theta information to place the pole
        # The simulated (cos, sin) rotates the content via affine_grid / grid_sample.
        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0, bs=T*self.bs) # cos , sin 
        grid = F.affine_grid(theta, torch.Size((T*self.bs, 1, d, d)))
        # Get the reconstruction images
        self.Xrec = F.grid_sample(self.content.view(T*self.bs, 1, d, d), grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])

        # Plot something to track the performance
        # NOTE(review): one figure per time step is created but never saved or closed,
        # so figures accumulate in matplotlib; the view(-1, 1, 64, 64) also hard-codes d == 64.
        if self.count % 200 == 0 and Plot_enable == True:
            for tt in range(T):
                fig1 = plt.figure(constrained_layout=False, figsize=(10,4))
                gs = fig1.add_gridspec(1, 2, width_ratios=[1.0,1.0])
                ax = fig1.add_subplot(gs[0, 0])
                from torchvision import utils
                grid = utils.make_grid(X[tt, 0].view(-1, 1, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())/ 255.0
                ax.imshow(X_)
                ax = fig1.add_subplot(gs[0, 1])

                grid = utils.make_grid(self.Xrec[tt, 0].view(-1, 1, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())
                ax.imshow(X_)
                

       
        # The call counter (used for the plotting cadence) advances only on full-length calls.
        if TIME_INDEX == 0:
            self.count += 1

        return None

    def training_step(self, train_batch, batch_idx):
        """Run ``forward`` at several start offsets and combine the per-offset losses.

        ``forward`` runs on the suffix ``X[TIME_INDEX:]`` for each
        ``TIME_INDEX`` in ``np.arange(0, T_pred - cutoff_freq_input, gap_interval)``.
        The first call (offset 0) produces ``self.mu_list``, which later
        offsets reuse.

        For every offset the method computes:
          - the reconstruction likelihood;
          - the KL term;
          - the norm penalty;
          - the encoder-vs-ODE state alignment (Time) loss;
          - the frame-velocity alignment loss;
          - for ``MODEL_TYPE`` 0/1, the DFT alignment loss.

        The final loss is a weighted sum of the means of these terms over
        offsets.

        Returns:
            dict with ``'loss'``, ``'log'`` and ``'progress_bar'`` entries.
            NOTE(review): this dict format looks like the older
            pytorch_lightning return convention; confirm against the pinned
            version.
        """

        #self.print_parameter()

        X, u, State = train_batch
        # X is in the shape of torch.Size([100, 256, 64, 64]) = [time, gray_scale,image_dim,image_dim]
        # T: simulation length: T = 100
        # size of X is (T+1, batch_size, 64, 64)
        # size of u is (64, 1), because of constant u
        # size of State is (T+1, batch_size, 7)

        lhood_list = []
        kl_q_list = []
        penalty_list = []
        Time_loss_list = []
        Time_pos_loss_list = []
        Time_vel_loss_list = []
        velocity_loss_list = []
        FFT_loss_list = []
        FFT_p_loss_list = []
        FFT_v_loss_list = []

        # iii indexes the precomputed masks in self.SRC_MAS_V, one per offset.
        iii = 0
        for TIME_INDEX in np.arange(0,T_pred-cutoff_freq_input ,gap_interval): # Default: 20 is the gap interval
            X_ = X[TIME_INDEX:,:,:,:]

            State_ = State[TIME_INDEX:,:,:]
            if TIME_INDEX != 0:
                self.forward(X_, u, State_,TIME_INDEX,self.SRC_MAS_V[iii],self.mu_list)
            else:
                self.forward(X_, u, State_,TIME_INDEX,self.SRC_MAS_V[iii])

            iii += 1

            # Compute the system parameter loss
            # NOTE(review): mse_mu, mse_L and R2_para are only assigned here, so the logs below
            # rely on the loop starting at TIME_INDEX == 0 and running at least once.
            # pred_mu / pred_L come from detached numpy values, so these MSEs carry no gradient.
            if TIME_INDEX == 0:
                true_mu = State[0,:,5]
                true_L = State[0,:,4]
                pred_mu = torch.tensor(np.array(self.mu_list))
                pred_L = torch.tensor(np.array(self.L_list))
                mse_mu = torch.mean(self.loss_fn(pred_mu.to(self.device),true_mu.to(self.device)))
                mse_L = torch.mean(self.loss_fn(pred_L.to(self.device),true_L.to(self.device)))

                #R2_para   = r2_score(list(pred_mu.cpu().detach().numpy()), \
                #                     list(true_mu.cpu().detach().numpy()))
                R2_para = 0
                print('pred_mu:',pred_mu)
                print('true_mu:',true_mu)
                #print('R2:',R2_para)

            # Get the Fourier Transform loss
            # The ODE-side spectra are detached, so only the encoder side receives this gradient.
            FFT_loss = 0
            FFT_cos_theta = 0
            FFT_theta_dot = 0
            if MODEL_TYPE == 0 or MODEL_TYPE == 1:
                FFT_cos_theta = self.loss_fn(self.FF_p_Enc,self.FF_p_ODE.detach()).sum([1]).mean()
                FFT_theta_dot = self.loss_fn(self.FF_v_Enc,self.FF_v_ODE.detach()).sum([1]).mean()
                FFT_loss = FFT_cos_theta + FFT_theta_dot

            # Get the time loss to align the states between the encoder and the ode solver
            T_pred_ = X_.shape[0] - 1
            print('ode state pos:', self.qT.view(T_pred_+1,self.bs,-1)[0:T_pred_][0:10,0,:])
            print('enc state pos:', self.qT_enc.view(T_pred_,self.bs,-1)[0:10,0,:])
            print('ode state vel:', self.q_dotT.view(T_pred_+1,self.bs,-1)[0:T_pred_][0:10,0,:])
            print('enc state vel:', self.q_dotT_enc.view(T_pred_,self.bs,-1)[0:10,0,:])

            # current version mean over everthing
            # The ODE states are detached: this loss pulls the encoder toward the simulation.

            if mean_over_everything == 1:
                Time_pos_loss = self.loss_fn_mean(self.qT.view(T_pred_+1,self.bs,-1)[0:T_pred_].detach(),self.qT_enc.view(T_pred_,self.bs,-1))
                Time_vel_loss = self.loss_fn_mean(self.q_dotT[0:self.q_dotT_enc.shape[0]].detach(),self.q_dotT_enc)
            else:      
                Time_pos_loss = self.loss_fn(self.qT.view(T_pred_+1,self.bs,-1)[0:T_pred_].detach(),self.qT_enc.view(T_pred_,self.bs,-1))
                Time_vel_loss = self.loss_fn(self.q_dotT[0:self.q_dotT_enc.shape[0]].detach(),self.q_dotT_enc)

            # The size of Time_pos_loss is [Time_Steps, BatchSize, 2]
            # NOTE: the two lines below are indented under the else branch, so this
            # sum/mean reduction applies only to the element-wise loss_fn path.
                Time_pos_loss = Time_pos_loss.sum([0,2]).mean() # Per batchsize over steps and states
                Time_vel_loss = Time_vel_loss.sum([0,2]).mean() # Per batchsize over steps and states
            #Time_pos_loss = Time_pos_loss.mean() # Per batchsize and steps over states
            #Time_vel_loss = Time_vel_loss.mean() # Per batchsize and steps over states
            Time_loss = Time_pos_loss + Time_vel_loss
  
            # Get the loss that forces the states by the encoder is the same as the one from comparing two frames
            velocity_loss =  self.loss_fn(self.q_dotT_enc,self.q_dotT_enc_frameV)
            velocity_loss = velocity_loss.sum([0, 2]).mean()

            ######### Compute the loss #########
            # current version
            lhood = - self.loss_fn(self.Xrec[0:Loss_first_index], X_[0:Loss_first_index])
            lhood = lhood.sum([0, 2, 3]).mean()
            #lhood = lhood.sum([2, 3]).mean()

            # The KL terms use the distributions forward left on self for the last trajectory of the batch.
            if model_variant==0:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean() \
                     + torch.distributions.kl.kl_divergence(self.Q_dot_q, self.P_normal).mean()
            elif model_variant==3:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
            else:
                if enable_attn == 1:
                    kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean() \
                         + torch.distributions.kl.kl_divergence(self.Q_dot_q, self.P_q).mean()
                else:
                    kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
            # Penalize the mean norm of the raw encoder means drifting away from 1.
            if enable_attn == 1:    
                norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1) ** 2 + (self.q0_dot_m.norm(dim=-1).mean() - 1) ** 2
            else:
                norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1) ** 2
            lambda_ = self.current_epoch/8000 if self.hparams.annealing else 1/100

            lhood_list.append(lhood)
            kl_q_list.append(kl_q)
            # NOTE(review): lambda_ is applied here AND again in the final loss below,
            # so the effective penalty weight is lambda_ ** 2 -- confirm this is intended.
            penalty_list.append(lambda_ * norm_penalty)
            Time_loss_list.append(Time_loss)
            Time_pos_loss_list.append(Time_pos_loss)
            Time_vel_loss_list.append(Time_vel_loss)
            velocity_loss_list.append(velocity_loss)
            FFT_loss_list.append(FFT_loss)
            FFT_p_loss_list.append(FFT_cos_theta)
            FFT_v_loss_list.append(FFT_theta_dot)

        #### Final loss function ####
        # Reconstruction loss
        # KL loss
        # Regulization loss
        # State Alignment loss
        # Velocity Alignment Loss

        # Stack the per-offset terms so their means can be taken below.
        lhood_list = torch.stack(lhood_list, axis=0)
        lhood_list.retain_grad()
        kl_q_list = torch.stack(kl_q_list, axis=0)
        kl_q_list.retain_grad()
        penalty_list = torch.stack(penalty_list, axis=0)
        penalty_list.retain_grad()
        Time_loss_list = torch.stack(Time_loss_list, axis=0)
        Time_loss_list.retain_grad()
        Time_pos_loss_list = torch.stack(Time_pos_loss_list, axis=0)
        Time_pos_loss_list.retain_grad()
        Time_vel_loss_list = torch.stack(Time_vel_loss_list, axis=0)
        Time_vel_loss_list.retain_grad()
        velocity_loss_list = torch.stack(velocity_loss_list, axis=0)
        velocity_loss_list.retain_grad()
        if MODEL_TYPE != 2:
            FFT_loss_list = torch.stack(FFT_loss_list, axis=0)
            FFT_loss_list.retain_grad()
            FFT_p_loss_list = torch.stack(FFT_p_loss_list, axis=0)
            FFT_p_loss_list.retain_grad()
            FFT_v_loss_list = torch.stack(FFT_v_loss_list, axis=0)
            FFT_v_loss_list.retain_grad()
        
        # current version
        #loss = Time_loss_list.mean()
        # Weighted sum of the per-offset means; MODEL_TYPE 2 has no DFT term.
        if MODEL_TYPE != 2:
            loss = - weight_recons * lhood_list.mean() \
                   + lambda_ * penalty_list.mean() \
                   + 1.0 * kl_q_list.mean()\
                   + Time_loss_weight * Time_loss_list.mean() \
                   + FFT_loss_weight * FFT_loss_list.mean()\
                   + velocity_loss_enable * velocity_loss_list.mean()
        else:
            loss = - weight_recons * lhood_list.mean() \
                   + lambda_ * penalty_list.mean() \
                   + 1.0 * kl_q_list.mean()\
                   + Time_loss_weight * Time_loss_list.mean() \
                   + velocity_loss_enable * velocity_loss_list.mean()

        # Disabled alternating-update scheme (toggle between updating f and h/g), kept as a string literal.
        '''
        if self.current_epoch % training_loss_interval == 0:
            if self.training_loss_flag == 0:
                self.training_loss_flag = 1
            else:
                self.training_loss_flag = 0

        if self.training_loss_flag == 0:
            print('update f, freeze h & g')
        elif self.training_loss_flag == 1:
            print('update h & g, freeze f')
        '''
        # freeze the parameter 
        '''
        if self.training_loss_flag == 1:
            # Freeze f(), only update g and h
            for param in self.recog_q_net.parameters():
                param.requires_grad = True
            for param in self.recog_q_net_velocity.parameters():
                param.requires_grad = True
            for param in self.MLP_Spec_mu.parameters():
                param.requires_grad = False   
            for param in self.obs_net.parameters():
                param.requires_grad = True   
            # compute the loss
            loss = - weight_recons * lhood_list.mean() \
                   + lambda_ * penalty_list.mean() \
                   + 1.0 * kl_q_list.mean()\

        elif self.training_loss_flag == 0:
            # Freeze the h() that predicts the states
            # only update f() that predicts the parameters
            for param in self.recog_q_net.parameters():
                param.requires_grad = False
            for param in self.recog_q_net_velocity.parameters():
                param.requires_grad = False
            for param in self.MLP_Spec_mu.parameters():
                param.requires_grad = True       
            for param in self.obs_net.parameters():
                param.requires_grad = False   
            # compute the loss
            loss = Time_loss_weight * Time_loss_list.mean()
        '''

        # Logged metrics; the DFT entries exist only when MODEL_TYPE != 2.
        if MODEL_TYPE != 2:
            logs = {'MSE_mu': mse_mu, \
                    'MSE_L': mse_L, \
                    'R2': R2_para, \
                    'Recons_Loss': -lhood_list.mean(), \
                    'State_DFT_Loss': FFT_loss_list.mean(), \
                    'State_DFT_Pos_Loss': FFT_p_loss_list.mean(), \
                    'State_DFT_Vel_Loss': FFT_v_loss_list.mean(), \
                    'State_Loss': Time_loss_list.mean(), \
                    'State_Pos_Loss': Time_pos_loss_list.mean(), \
                    'State_Vel_Loss': Time_vel_loss_list.mean(), \
                    'VelocityAlignment_Frame_Loss': velocity_loss_list.mean(), \
                    'KL_loss': kl_q_list.mean(), \
                    'Regulization_loss':  penalty_list.mean(), \
                    'Regulization_loss_lambda':  lambda_, \
                    'loss': loss, \
                    'monitor': loss}
        else: 
            logs = {'MSE_mu': mse_mu, \
                        'MSE_L': mse_L, \
                        'R2': R2_para, \
                        'Recons_Loss': -lhood_list.mean(), \
                        'State_Loss': Time_loss_list.mean(), \
                        'State_Pos_Loss': Time_pos_loss_list.mean(), \
                        'State_Vel_Loss': Time_vel_loss_list.mean(), \
                        'VelocityAlignment_Frame_Loss': velocity_loss_list.mean(), \
                        'KL_loss': kl_q_list.mean(), \
                        'Regulization_loss':  penalty_list.mean(), \
                        'Regulization_loss_lambda':  lambda_, \
                        'loss': loss, \
                        'monitor': loss}

        # Log the running loss
        return {'loss':loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters at ``hparams.learning_rate``."""
        learning_rate = self.hparams.learning_rate
        return torch.optim.Adam(params=self.parameters(), lr=learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a child ``ArgumentParser`` that extends ``parent_parser`` with model hyperparameters.

        Adds, in this order:
          - ``--learning_rate`` (float, default: module-level ``lr``)
          - ``--batch_size`` (int, default: module-level ``num_batch``)

        ``add_help=False`` keeps the child from adding its own ``-h`` next to
        the parent's.
        """
        model_parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # (flag, default, type) in registration order.
        hyperparams = (
            ('--learning_rate', lr, float),
            ('--batch_size', num_batch, int),
        )
        for flag, default_value, cast in hyperparams:
            model_parser.add_argument(flag, default=default_value, type=cast)

        return model_parser

def main(args):
    """Build the model for the configured dataset and train it with Lightning.

    The dataset is selected by the module-level ``dataset_type`` and prefixed
    with the module-level ``dataset_folder``. Gradient clipping uses the
    module-level ``gradient_clip``. All three are defined elsewhere in this
    file.

    Args:
        args: parsed argparse namespace (model + Trainer arguments). Its
            ``name`` attribute is used as the checkpoint directory and as the
            log sub-directory.

    Raises:
        ValueError: if ``dataset_type`` does not match any known dataset id.
            Previously an unknown id left ``dataset_name`` unbound and failed
            with an ``UnboundLocalError``.
    """
    # Map each supported dataset_type id to its pickle filename.
    dataset_files = {
        0: 'pendulum-gym-image-dataset-train_L0.1-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaTrue.pkl',
        1: 'pendulum-gym-image-dataset-train_20Hz_HighinitV.pkl',
        2: 'pendulum-gym-image-dataset-train_20Hz_HighinitV_mu.pkl',
        3: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaTrue.pkl',
        4: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaFalse.pkl',
        5: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL405_nS32_DaFalse.pkl',
        6: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL125_nS100_DaTrue.pkl',
        7: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL125_nS100_DaTrue.pkl',
        8: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-0.1_Hz20_sL102_nS10_DaTrue.pkl',
        9: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv5.0-10.0.pkl',
        10: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv0.5-4.0.pkl',
        11: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv0.5-1.0_Sp-1.5707963267948966-1.5707963267948966.pkl',
        12: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS10_DaTrue_Sv1.0-1.0_Sp-1.5707963267948966-1.5707963267948966.pkl',
        13: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS20_DaTrue_Sv1.0-1.0_Sp1.5707963267948966-1.5707963267948966.pkl',
        14: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL125_nS1000_DaFalse_Sv0.5-4.0_Sp-3.141592653589793-3.141592653589793.pkl',
    }
    try:
        dataset_file = dataset_files[dataset_type]
    except KeyError:
        raise ValueError(
            f'Unknown dataset_type {dataset_type!r}; '
            f'expected one of {sorted(dataset_files)}'
        ) from None
    dataset_name = dataset_folder + dataset_file

    model = Model(hparams=args, data_path=os.path.join(PARENT_DIR, 'datasets', dataset_name))

    # doc link for "ModelCheckpoint"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/model_checkpoint.py
    # Keep the 5 best checkpoints by the logged 'monitor' value, plus the last one.
    checkpoint_callback = ModelCheckpoint(monitor='monitor',
                                          dirpath=args.name + '/',
                                          filename='Model-{epoch:05d}-{loss:.2f}',
                                          save_top_k=5,
                                          save_last=True)

    # doc link for "Trainer"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py
    # NOTE(review): passing a ModelCheckpoint via `checkpoint_callback=` is the
    # older Lightning API; newer releases expect `callbacks=[...]` — confirm
    # against the installed version before upgrading.
    trainer = Trainer.from_argparse_args(args,
                                         limit_train_batches=1,
                                         max_epochs=10000,
                                         deterministic=True,
                                         terminate_on_nan=True,
                                         log_every_n_steps=1,
                                         default_root_dir=os.path.join(PARENT_DIR, 'logs', args.name),
                                         checkpoint_callback=checkpoint_callback,
                                         gradient_clip_val=gradient_clip,
                                         track_grad_norm=2)

    trainer.fit(model)

if __name__ == '__main__':

    parser = ArgumentParser(add_help=False)
    parser.add_argument('--name', default=save_dir, type=str)
    parser.add_argument('--T_pred', default=T_pred, type=int)
    # ODE solver name, e.g. 'rk4' or 'euler' (presumably forwarded to
    # torchdiffeq's odeint `method` — confirm in the model code).
    parser.add_argument('--solver', default='rk4', type=str)
    parser.add_argument('--homo_u', dest='homo_u', action='store_true')
    # Annealing defaults to on. `--annealing` is kept for backward
    # compatibility, but it is a no-op given the default below.
    # `--no_annealing` is the only way to switch annealing off.
    parser.add_argument('--annealing', dest='annealing', action='store_true')
    parser.add_argument('--no_annealing', dest='annealing', action='store_false')
    parser.set_defaults(homo_u=False, annealing=True)
    # Add args from trainer
    parser = Trainer.add_argparse_args(parser)
    # Give the module a chance to add its own params
    # Good practice to define LightningModule-specific params in the module
    parser = Model.add_model_specific_args(parser)
    # Parse params
    args = parser.parse_args()

    main(args)


    # make gif
    #https://ezgif.com/maker/ezgif-6-69fc2d3f-gif