import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import imageio
from cal_constraints import *
from proj_io_adapter_io_adapter import *


# Run name: used as the GIF file name (./gif/<DEMO>.gif), as the suffix of the
# viewer output directory ("output_" + DEMO), and passed to run() and cal_mse().
# NOTE(review): Proj loads models/rigid_16.pt regardless of this name — confirm intended.
DEMO = "rigid_44"

def plot_gif(before_project, after_project, arrow, name='0', out_dir='./gif', fps=10, xy_max=3):
    """Render one frame per simulation step and save them as an animated GIF.

    Each frame overlays the points before projection (yellow) and after
    projection (blue), plus the force arrows, on a fixed square view.

    Args:
        before_project: sequence of 2-D arrays (one per frame); columns 0 and 1
            are the x / y coordinates of the points before projection.
        after_project: sequence of 2-D arrays with the projected points; must
            have the same length as ``before_project``.
        arrow: sequence of 4-row arrays whose rows are the X, Y, U, V inputs of
            ``Axes.quiver``; must have the same length as ``before_project``.
        name: GIF file name without extension.
        out_dir: directory the GIF is written to; created if it does not exist.
        fps: frame rate passed to ``imageio.mimsave``.
        xy_max: half-width of the square plotting window.

    Raises:
        ValueError: if the three sequences do not have the same length.
    """
    import os

    if not len(before_project) == len(after_project) == len(arrow):
        raise ValueError(
            'before_project, after_project and arrow must have the same length, got '
            f'{len(before_project)}, {len(after_project)} and {len(arrow)}')

    def generate_one_frame(data, pred, arrows):
        # Draw a single frame and return it as an (H, W, 3) uint8 RGB image.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            X, Y, U, V = arrows  # arrow tails (X, Y) and arrow vectors (U, V)
            ax.quiver(X, Y, U, V, angles='xy', scale_units='xy', scale=1)

            ax.scatter(data[:, 0], data[:, 1], c='y')
            ax.scatter(pred[:, 0], pred[:, 1], c='b')
            ax.grid()
            ax.set(xlabel='X', ylabel='Y', title='yellow: points before projection; blue: after projection.')
            ax.set_xlim(-xy_max, xy_max)
            ax.set_ylim(-xy_max, xy_max)
            fig.canvas.draw()

            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
            # removed in 3.10). It is already (height, width, 4) in device pixels,
            # so no manual reshape is needed and HiDPI canvases work too.
            # Copy it because the buffer belongs to the figure closed below.
            image = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # Without this, every frame's figure stays open for the whole run.
            plt.close(fig)
        return image

    os.makedirs(out_dir, exist_ok=True)
    frames = [
        generate_one_frame(points, projected, arrows)
        for points, projected, arrows in zip(before_project, after_project, arrow)
    ]
    imageio.mimsave(os.path.join(out_dir, name + '.gif'), frames, fps=fps)

class Proj():
    """Wrap a serialized torch model that maps a point set to new positions."""

    def __init__(self, model_path='models/rigid_16.pt'):
        """Load the whole pickled model stored at ``model_path`` onto the CPU.

        Args:
            model_path: file written with ``torch.save(model)``, i.e. a complete
                module rather than a state dict. Defaults to the previously
                hard-coded model file.
        """
        # map_location='cpu' lets a model saved from a GPU load on a CPU-only
        # machine; the trailing .cpu() alone runs only after loading succeeded.
        # weights_only=False is needed to unpickle a whole nn.Module on
        # torch >= 2.6 (where the default became True). Unpickling can execute
        # arbitrary code, so only load trusted model files.
        self.model = torch.load(model_path, map_location='cpu', weights_only=False).cpu()

    def project(self, data):
        """Run the model on one point set and return its output as a numpy array.

        Args:
            data: 2-D numpy array; a leading batch axis is added before the
                model call and removed from the result. Values are copied into
                a float32 tensor.

        Returns:
            numpy array holding ``model(data[None])[0]``, detached from autograd.
        """
        batch = torch.tensor(data[None, :, :], dtype=torch.float32)
        pred = self.model(batch)
        return pred[0, :, :].detach().numpy()

def test():
    """Simulate a 16-point set pulled by two rotating forces and export the run.

    Each step applies opposite rotating forces to two driven points and advances
    all points with an explicit Euler step. The moved points go through
    ``Proj.project``, and the velocities are then recomputed from the projected
    displacement. The positions before and after projection and the force arrows
    are handed to ``plot_gif``, ``run``, ``write_viewer`` and ``cal_mse`` under
    the run name ``DEMO``.
    """
    data = np.array([[-0.0705747, -0.0477225],
                     [-0.143344, 0.308463],
                     [-0.037158, 0.483448],
                     [-0.112173, 0.829575],
                     [0.233954, 0.0116891],
                     [0.166068, 0.207133],
                     [0.111511, 0.50575],
                     [0.160129, 0.709886],
                     [0.364264, -0.065618],
                     [0.338761, 0.206584],
                     [0.360962, 0.466334],
                     [0.370712, 0.765977],
                     [0.670355, 0.0703561],
                     [0.724735, 0.256815],
                     [0.661194, 0.515199],
                     [0.669577, 0.747849]])
    initial_shape = np.array(data)

    dt = 0.1          # integration time step
    n_steps = 100     # number of simulated steps
    a, b = 0, 15      # indices of the two points the forces act on
    arrow_scale = 20  # forces are divided by this so the arrows fit the plot

    def force1(t):
        # Rotating force of magnitude 15 with angular frequency 2.
        return np.array([np.sin(t * 2), np.cos(t * 2)]) * 15

    def force2(t):
        # Equal and opposite to force1.
        return np.array([-np.sin(t * 2), -np.cos(t * 2)]) * 15

    def arrow_frame(points, f_a, f_b):
        # Rows X, Y, U, V for the quiver in plot_gif: arrow tails at the two
        # driven points, arrow vectors equal to the scaled forces.
        return np.array([[points[a, 0], points[b, 0]],
                         [points[a, 1], points[b, 1]],
                         [f_a[0] / arrow_scale, f_b[0] / arrow_scale],
                         [f_a[1] / arrow_scale, f_b[1] / arrow_scale]])

    vel = np.zeros_like(data)
    data_list = [np.array(data)]
    proj_list = [np.array(data)]
    force_list = [arrow_frame(data, np.zeros(2), np.zeros(2))]  # no force in frame 0
    pr = Proj()
    for step in range(n_steps):
        t = step * dt
        f_a = force1(t)
        f_b = force2(t)

        vel[a] += f_a * dt
        vel[b] += f_b * dt

        # Explicit Euler step for all points at once. astype keeps data's dtype
        # (the model's output dtype after the first projection), matching the
        # former element-wise assignment into a copy of data.
        moved = (data + vel * dt).astype(data.dtype)
        data_list.append(moved)
        force_list.append(arrow_frame(moved, f_a, f_b))

        projected = pr.project(moved)
        proj_list.append(np.array(projected))

        # Velocity implied by the projection; vel[:] keeps vel's own dtype.
        vel[:] = (projected - data) / dt
        data = np.array(projected)

    plot_gif(data_list, proj_list, force_list, DEMO)
    # NOTE(review): the meaning of [1, 0, 0, 0, 0] and of 2 below is defined by
    # run / cal_mse (presumably from the star imports) — confirm there.
    run(proj_list, [1, 0, 0, 0, 0], initial_shape, DEMO)
    write_viewer(proj_list, output_dir="output_" + DEMO)
    cal_mse(proj_list, 2, DEMO)

# Run the demo simulation only when executed as a script, not on import.
if __name__ == '__main__':
    test()


