import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import imageio
from cal_constraints import *
from proj_io_adapter_io_adapter import *

# Demo identifier: used as the GIF file stem, passed to run / run_e, and
# used to build the viewer output directory name ("output_" + DEMO).
DEMO = "collision_2"

def plot_gif(before_project, after_project, arrow, boundary_list, name = '0'):
    """Render a sequence of simulation states to the animated GIF ``./gif/<name>.gif``.

    Frame ``i`` draws ``before_project[i]`` (yellow), ``after_project[i]``
    (blue), ``boundary_list[i]`` (red), a red circle of radius 2 at the origin
    and one quiver arrow ``arrow[i]`` given as ``(x, y, u, v)``. Every frame
    is rendered. The four sequences are consumed in lockstep, so the shortest
    one bounds the frame count.

    Args:
        before_project: per-frame 2-D point arrays (columns x, y) before projection.
        after_project: per-frame 2-D point arrays (columns x, y) after projection.
        arrow: per-frame 4-element sequences ``(x, y, u, v)`` for ``ax.quiver``.
        boundary_list: per-frame 2-D point arrays (columns x, y) drawn in red.
        name: GIF file stem; the file is written under ``./gif/`` (created if missing).
    """
    import os

    def generate_one_frame(data, pred, arrow, boundary, xy_max = 2):
        """Draw one frame and return it as an (H, W, 3) uint8 RGB array."""
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.add_artist(plt.Circle((0, 0), 2.0, color='r', fill=False))

            # arrow is a single (x, y, u, v); zip(*[arrow]) turns it into 1-element tuples.
            X, Y, U, V = zip(*[arrow])
            ax.quiver(X, Y, U, V, angles='xy', scale_units='xy', scale=1)

            ax.scatter(boundary[:, 0], boundary[:, 1], c='r')
            ax.scatter(data[:, 0], data[:, 1], c='y')
            ax.scatter(pred[:, 0], pred[:, 1], c='b')
            ax.grid()
            ax.set(xlabel='X', ylabel='Y', title='yellow: points before projection; blue: after projection.')
            ax.set_xlim(-xy_max, xy_max)
            ax.set_ylim(-xy_max, xy_max)
            fig.canvas.draw()

            # buffer_rgba() is already shaped (H, W, 4) in device pixels. tostring_rgb()
            # is deprecated (Matplotlib 3.8) / removed (3.10), and reshaping it with
            # get_width_height() breaks on HiDPI canvases. Copy so the frame owns its memory.
            image = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # pyplot keeps every figure alive until closed; without this each frame leaks one.
            plt.close(fig)
        return image

    os.makedirs('./gif', exist_ok=True)
    frames = [
        generate_one_frame(before, after, arr, boundary)
        for before, after, arr, boundary in zip(before_project, after_project, arrow, boundary_list)
    ]
    imageio.mimsave('./gif/' + name + '.gif', frames, fps=10)

class Proj:
    """Wrapper around a pickled PyTorch model used to project point sets.

    The model is loaded once at construction (onto the CPU) and then applied
    to one point set at a time via :meth:`project`.
    """

    def __init__(self, model_path='models/colli.pt'):
        """Load the model from *model_path* onto the CPU and print its ``iter`` attribute.

        Args:
            model_path: path to a file written with ``torch.save`` of the model
                object itself (it must be callable, support ``.cpu()`` and carry
                an ``iter`` attribute). Defaults to the original hard-coded path.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses a pickled
            # model object; map_location lets a GPU-saved model load on CPU-only hosts.
            model = torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            model = torch.load(model_path, map_location='cpu')
        self.model = model.cpu()
        print(self.model.iter)

    def project(self, data):
        """Run the model on a single point set and return its output as numpy.

        Args:
            data: 2-D numpy array; a leading batch axis of size 1 is added
                before the call and removed from the output.

        Returns:
            The model output for that batch element, as a float32 numpy array
            (the input is copied to float32 before the call).
        """
        batch = torch.tensor(data[None, :, :], dtype=torch.float32)
        # No torch.no_grad() here on purpose: the model may rely on autograd internally.
        pred = self.model(batch)
        return pred[0, :, :].detach().numpy()

def get_boundary(new_data):
    """Return four corner points built around the first four rows of *new_data*.

    Let ``c`` be the mean of rows 0-3, ``a = row3 - c`` and ``b = row1 - c``.
    Corner ``r`` is ``wa * a + wb * b + c`` with ``(wa, wb)`` taken, in order,
    from (0.6, -1.2), (1.2, 0.6), (-0.6, 1.2), (-1.2, -0.6).

    Returns:
        A new (4, 2) float64 array of corner coordinates.
    """
    p0, p1, p2, p3 = new_data[0, :], new_data[1, :], new_data[2, :], new_data[3, :]
    center = (p0 + p1 + p2 + p3) / 4
    offset_a = p3 - center
    offset_b = p1 - center

    weights = ((0.6, -1.2), (1.2, 0.6), (-0.6, 1.2), (-1.2, -0.6))
    corners = np.zeros([4, 2])
    for row, (wa, wb) in enumerate(weights):
        corners[row, :] = wa * offset_a + wb * offset_b + center
    return corners

def get_full_data(boundary_data, n=4):
    """Fill the quadrilateral spanned by four corners with an ``n x n`` grid of points.

    Bilinear interpolation over the corners ``b0..b3`` (rows of *boundary_data*):
    along ``j`` the first edge runs ``b2 -> b3`` and the second ``b1 -> b0``;
    along ``i`` the point moves from the second edge (``i = 0``) to the first
    (``i = n - 1``). Output row ``i * n + j`` holds grid point ``(i, j)``.

    Args:
        boundary_data: array with at least 4 rows of (x, y) corners.
        n: points per side. The default of 4 reproduces the original 16-point grid.

    Returns:
        A new (n * n, 2) float64 array.

    Raises:
        ValueError: if ``n < 2`` (the interpolation divides by ``n - 1``).
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    last = n - 1

    # Both edge interpolants depend only on j, so compute them once per column.
    edge_first = [(j * boundary_data[3, :] + (last - j) * boundary_data[2, :]) / last for j in range(n)]
    edge_second = [(j * boundary_data[0, :] + (last - j) * boundary_data[1, :]) / last for j in range(n)]

    full_data = np.zeros([n * n, 2])
    for i in range(n):
        for j in range(n):
            full_data[i * n + j, :] = (i * edge_first[j] + (last - i) * edge_second[j]) / last
    return full_data


def _compute_frames(points):
    """Compute corner points and interior grids for four 4-point bodies.

    Args:
        points: (16, 2) array; rows ``4*b .. 4*b+3`` belong to body ``b``.

    Returns:
        ``(boundary, full)``. ``boundary`` is (16, 2): the get_boundary corners
        of each body, stacked in body order. ``full`` is (64, 2): the
        get_full_data grid of each body, stacked in body order.
    """
    boundary = np.zeros([16, 2])
    full = np.zeros([64, 2])
    for body in range(4):
        corner_rows = slice(body * 4, (body + 1) * 4)
        boundary[corner_rows, :] = get_boundary(points[corner_rows, :])
        full[body * 16:(body + 1) * 16, :] = get_full_data(boundary[corner_rows, :])
    return boundary, full


def test(te = 0):
    """Run the four-body collision demo and export its results.

    Four translated copies of one 4-point body start with per-body velocities
    and are advanced for 100 steps of size 0.1 under gravity. After each
    prediction step, every pair of bodies is passed jointly through
    ``Proj.project`` (10 sweeps over the 6 pairs). Velocities are then
    recomputed from the corrected displacement. The recorded states are handed
    to ``plot_gif``, ``run``, ``run_e`` and ``write_viewer``.

    Args:
        te: unused; kept so existing callers keep working.
    """
    # Reference 4-point body; the four simulated bodies are translated copies of it.
    rigid = np.array([[0.130514, 0.437662],
                      [0.0862001, 0.882675],
                      [0.575526, 0.481976],
                      [0.531213, 0.926988]])
    offsets = ([-1.2, -0.5], [0.25, -0.5], [-0.2, 0.6], [-1.3, 0.5])
    data = np.concatenate([rigid + np.array(offset) for offset in offsets], axis=0)

    boundary_data, full_data = _compute_frames(data)

    timestamp = 0.1
    # One initial velocity per body, repeated for each of its 4 points.
    vel = np.repeat(np.array([[3.0, 0.0], [1.0, 0.0], [4.0, 0.0], [4.0, 0.0]]), 4, axis=0)

    def force(t):
        # External force applied to row 3; the trailing "* 0" currently disables it.
        return np.array([0.5, np.cos(t)]) * 0

    g = np.array([0, -5])
    vel *= 0.3

    # Four copies of the same 4-corner layout, passed to run / run_e.
    initial_boundary = np.tile(np.array([[0.6, 0.0], [0.6, 0.6], [0.0, 0.6], [0.0, 0.0]]), (4, 1))

    data_list = [np.array(data)]
    proj_list = [np.array(data)]
    boundary_list = [boundary_data]
    full_list = [full_data]
    force_list = [np.array([0, 0, 0, 0])]

    pr = Proj()

    print(len(proj_list))

    # Every unordered pair of bodies, in the order (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
    body_pairs = [(j, k) for j in range(4) for k in range(j + 1, 4)]
    gravity_step = timestamp * g

    for ite in range(100):
        f_3 = force(ite * timestamp)
        vel[3, :] += f_3 * timestamp
        # Prediction: integrate gravity into velocity, then velocity into position.
        vel += gravity_step
        new_data = data + vel * timestamp

        data_list.append(np.array(new_data))
        force_list.append(np.array([new_data[3, 0], new_data[3, 1], f_3[0] / 50, f_3[1] / 50]))

        # Correction: project each pair of bodies jointly, 10 sweeps over all pairs.
        for _ in range(10):
            for j, k in body_pairs:
                rows_j = slice(j * 4, (j + 1) * 4)
                rows_k = slice(k * 4, (k + 1) * 4)
                pair = pr.project(np.concatenate((new_data[rows_j, :], new_data[rows_k, :]), axis=0))
                new_data[rows_j, :] = pair[0:4, :]
                new_data[rows_k, :] = pair[4:8, :]

        boundary_data, full_data = _compute_frames(new_data)

        proj_list.append(np.array(new_data))
        boundary_list.append(np.array(boundary_data))
        full_list.append(np.array(full_data))

        # Velocity implied by the corrected positions.
        vel = (new_data - data) / timestamp
        data = np.array(new_data)

    print(len(full_list))
    print(len(data_list))
    print(len(proj_list))

    plot_gif(data_list, proj_list, force_list, full_list, DEMO)
    run(boundary_list, [0, 0, 0, 0, 1], initial_boundary, DEMO)
    run_e(boundary_list, [1, 0, 0, 0, 0], initial_boundary, DEMO)
    write_viewer(full_list, output_dir = "output_" + DEMO)




if __name__ == "__main__":
    # Script entry point: run the collision demo with default arguments.
    test()
