import torch
from torch.utils.data import *
from matplotlib import pyplot as plt
import numpy as np

def _read_selected_points(path, n_steps, n_points, dim, selected):
    """Parse one sample file and keep only the ``selected`` point indices.

    The file is read as ``n_steps`` frames; each frame lists ``n_points``
    points with ``dim`` coordinates, one float per line, in
    (step, point, coord) order. Lines beyond ``n_steps * n_points * dim``
    are ignored, and lines of non-selected points are never parsed.

    Returns:
        float32 tensor of shape (n_steps, len(selected), dim); the kept points
        appear in ascending index order.

    Raises:
        ValueError: if the file has too few lines or a kept line is not a float.
    """
    frame_len = n_points * dim
    needed = n_steps * frame_len
    # `with` guarantees the handle is closed even if parsing fails.
    with open(path, "r") as f:
        lines = f.readlines()
    if len(lines) < needed:
        raise ValueError(f'{path}: expected at least {needed} lines, found {len(lines)}')
    order = sorted(selected)
    values = [float(lines[t * frame_len + j * dim + k])
              for t in range(n_steps) for j in order for k in range(dim)]
    return torch.tensor(values, dtype=torch.float32).reshape(n_steps, len(order), dim)


def read_data(start=0, end=8192,
              root_path='data/distance_constraint_rigid_collision_32_4/',
              max_error=1.0):
    """Load paired data/label frames for sample indices ``start`` .. ``end - 1``.

    Sample ``i`` is stored as two text files under ``root_path``:
    ``data_distance_constraint_rigid_collision_32_4_<i>_ori.txt`` (data) and
    ``data_distance_constraint_rigid_collision_32_4_<i>.txt`` (label). Each file
    holds 32 frames of 32 points x 2 coordinates, one float per line. Only the
    8 points with indices [1, 7, 8, 14, 17, 23, 24, 30] are kept. Every frame
    becomes one row.

    Rows whose total absolute difference between data and label (over all
    kept points and coordinates) is not below ``max_error`` are dropped.
    Progress and summary statistics are printed to stdout.

    Args:
        start: first sample index (inclusive).
        end: last sample index (exclusive).
        root_path: directory prefix containing the files; it must end with a
            path separator because file names are appended directly.
        max_error: rows with summed |data - label| >= this value are dropped.

    Returns:
        (data, label): float32 tensors of shape (M, 8, 2), where M is the number
        of rows that passed the filter (possibly 0).

    Raises:
        OSError: if a sample file cannot be opened.
        ValueError: if a file is too short or contains a non-float kept line.
    """
    file_prefix = root_path + 'data_distance_constraint_rigid_collision_32_4_'

    TS = 32    # frames per sample file
    PN_A = 32  # points stored per frame
    DIM = 2    # coordinates per point

    # Point indices to keep: the same four offsets in each half of the 32 points.
    temp = [1, 7, 8, 14]  # alternative selection: [0, 3, 12, 15]
    temp += [i + 16 for i in temp]
    print(temp)

    N = max(end - start, 0)
    data = torch.empty([N * TS, len(temp), DIM])
    label = torch.empty([N * TS, len(temp), DIM])
    # Rows are sample-major: rows n*TS .. n*TS+TS-1 hold sample start+n.
    for n, i in enumerate(range(start, end)):
        rows = slice(n * TS, (n + 1) * TS)
        data[rows] = _read_selected_points(file_prefix + str(i) + '_ori.txt', TS, PN_A, DIM, temp)
        label[rows] = _read_selected_points(file_prefix + str(i) + '.txt', TS, PN_A, DIM, temp)

    # Per-row total absolute error over all kept points and coordinates.
    errors = np.abs(data.numpy() - label.numpy()).sum(axis=(1, 2))
    print(errors.size)
    index = np.flatnonzero(errors < max_error)
    print(index.shape)
    keep = torch.from_numpy(index)
    data = data[keep]
    label = label[keep]

    n_rows = data.shape[0]
    if n_rows == 0:
        # Nothing survived the filter; the means below would divide by zero.
        print('no rows with error below', max_error)
    else:
        diff = data - label
        print((diff ** 2).sum() / n_rows)  # mean per-row sum of squared differences
        print(diff.abs().sum() / n_rows)   # mean per-row sum of absolute differences
    print(len(data))

    return data, label
               
    

def get_data_loader(device='cuda'):
    """Split the default ``read_data()`` output into train/validation datasets.

    The first ``num // 10 * 9`` rows (about 90%) go to training and the rest
    go to validation. Despite the name, this returns ``TensorDataset`` objects,
    not ``DataLoader``s.

    Args:
        device: torch device the tensors are moved to. The default ``'cuda'``
            preserves the previous hard-coded ``.cuda()`` behaviour; pass
            ``'cpu'`` to run without a GPU.

    Returns:
        (train_ds, val_ds): ``TensorDataset`` instances of (data, label) pairs.
    """
    d, l = read_data()
    split = d.shape[0] // 10 * 9
    d_tr = d[:split].to(device)
    d_val = d[split:].to(device)
    print(d_tr.shape)
    print(d_val.shape)
    l_tr = l[:split].to(device)
    l_val = l[split:].to(device)
    return TensorDataset(d_tr, l_tr), TensorDataset(d_val, l_val)

def plot(d, l):
    """Scatter point 0 of every row (``d`` in red, then ``l`` in blue) and show it."""
    for series, colour in ((d, 'r'), (l, 'b')):
        xs = series[:, 0, 0]
        ys = series[:, 0, 1]
        plt.scatter(xs, ys, c=colour)
    plt.show()
    

def view_data(d, l):
    
    data = d.numpy()
    label = l.numpy()

    # print(label[0,:,:])

    for i in range(data.shape[0]):
        plt.scatter(data[i,:,0], data[i,:,1], c = 'y')
        plt.scatter(label[i,:,0], label[i,:,1], c = 'g')
        error = np.sum(np.sum(np.abs((label[i,:,0] - data[i,:,0]))))
        plt.title(str(i) + 'error: ' + str(error))
        plt.show()


if __name__ == '__main__':
    # Load the first 40 samples, report tensor shapes, then browse them row by row.
    samples, targets = read_data(0, 40)
    print(samples.shape)
    print(targets.shape)
    view_data(samples, targets)