import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import *

def _sample_paths(root_path, i):
    """Return the (input, target) file paths for sample file index *i*."""
    stem = os.path.join(root_path, 'data_distance_constraint_square_rope_6_' + str(i))
    return stem + '_ori.txt', stem + '.txt'


def _read_floats(path, count):
    """Parse the first *count* lines of *path* as floats into a 1-D tensor.

    Raises ValueError if a line is not a float, including the empty string
    returned by readline() when the file has fewer than *count* lines.
    """
    with open(path, "r") as f:
        values = [float(f.readline()) for _ in range(count)]
    return torch.tensor(values)


def read_data(start=0, end=6000, root_path='data/distance_constraint_square_rope_6/',
              max_error=2.0):
    """Load paired (input, target) point sets from text files and keep close pairs.

    For every file index ``i`` in ``[start, end)`` two files are read from
    *root_path*:

    * ``data_distance_constraint_square_rope_6_<i>_ori.txt`` -> input  (``data``)
    * ``data_distance_constraint_square_rope_6_<i>.txt``     -> target (``label``)

    Each file supplies ``TS * P_N * DIM`` = 32 * 10 * 2 = 640 floats, one per
    line, ordered time step -> point -> coordinate. Lines after the 640th are
    ignored. Every time step becomes one sample row.

    Rows whose summed absolute difference ``sum |data - label|`` is not below
    *max_error* are dropped. Some diagnostics are printed: the number of rows
    before filtering, the shape of the kept index, the mean squared and mean
    absolute difference per kept row, and the kept count.

    Args:
        start: first file index (inclusive).
        end: last file index (exclusive).
        root_path: directory containing the files.
        max_error: keep threshold on the per-row summed absolute difference.

    Returns:
        ``(data, label)``, two tensors of shape ``(K, 10, 2)``, where ``K`` is
        the number of rows kept.

    Raises:
        OSError: if a file cannot be opened.
        ValueError: if a file is short or contains a non-float line.
    """
    TS = 32
    P_N = 10
    DIM = 2
    per_file = TS * P_N * DIM

    N = end - start
    data = torch.empty([N * TS, P_N, DIM])
    label = torch.empty([N * TS, P_N, DIM])
    for i in range(N):
        data_file, label_file = _sample_paths(root_path, i + start)
        rows = slice(i * TS, (i + 1) * TS)
        # Line order is t -> j -> k, which is exactly row-major (TS, P_N, DIM).
        data[rows] = _read_floats(data_file, per_file).view(TS, P_N, DIM)
        label[rows] = _read_floats(label_file, per_file).view(TS, P_N, DIM)

    # Per-row L1 distance between input and target; keep only close pairs.
    errors = np.sum(np.abs(data.numpy() - label.numpy()), axis=(1, 2))
    print(errors.size)
    index = np.flatnonzero(errors < max_error)
    print(index.shape)
    keep = torch.from_numpy(index)
    data = data[keep]
    label = label[keep]

    kept = data.shape[0]
    if kept:
        diff = data - label
        mean_sq = (diff ** 2).sum() / kept
        mean_abs = diff.abs().sum() / kept
    else:
        # Nothing survived the filter; avoid dividing by zero.
        mean_sq = mean_abs = float('nan')
    print(mean_sq)
    print(mean_abs)
    print(kept)

    return data, label
               
    

def get_data_loader(device='cuda'):
    """Load the full dataset with read_data() and split it into train/val datasets.

    The first 90% of rows (rounded down) form the training set and the rest
    form the validation set. Rows are not shuffled before splitting.

    Args:
        device: device the tensors are moved to. The default 'cuda' keeps the
            original behavior; pass 'cpu' on machines without a GPU.

    Returns:
        ``(train_ds, val_ds)``, two TensorDatasets of (input, target) pairs.
    """
    d, l = read_data()
    num = d.shape[0]
    # Multiply before dividing: the old `num // 10 * 9` truncated first, which
    # left 0 training rows when num < 10 and could drop up to 9 rows otherwise.
    idx = num * 9 // 10
    d_tr = d[:idx].to(device)
    d_val = d[idx:].to(device)
    print(d_tr.shape)
    print(d_val.shape)
    l_tr = l[:idx].to(device)
    l_val = l[idx:].to(device)
    train_ds = TensorDataset(d_tr, l_tr)
    val_ds = TensorDataset(d_val, l_val)
    return train_ds, val_ds

def plot(d, l):
    """Scatter the first point of every sample: inputs in red, targets in blue.

    Blocks until the matplotlib window is closed.
    """
    first_input = d[:, 0, :]
    first_target = l[:, 0, :]
    plt.scatter(first_input[:, 0], first_input[:, 1], c='r')
    plt.scatter(first_target[:, 0], first_target[:, 1], c='b')
    plt.show()
    

def view_data(d, l, xy_max=2):
    """Show one scatter plot per sample comparing input and target points.

    Input points are drawn in yellow and target points in green. The title
    carries the sample index and the summed absolute difference between the
    two point sets. Calls plt.show() once per sample, so it blocks until each
    window is closed.

    Args:
        d: input tensor indexed as ``d[i, :, 0]`` / ``d[i, :, 1]`` (rank 3).
            It may live on any device.
        l: target tensor with the same indexing as *d*.
        xy_max: half-width of the square viewing window.
    """
    # .cpu() lets CUDA tensors (e.g. from get_data_loader) be viewed too.
    data = d.detach().cpu().numpy()
    label = l.detach().cpu().numpy()

    for i, (pts, target) in enumerate(zip(data, label)):
        plt.scatter(pts[:, 0], pts[:, 1], c='y')
        plt.scatter(target[:, 0], target[:, 1], c='g')
        error = np.abs(target - pts).sum()
        plt.title(f'{i} error: {error}')
        plt.xlim(-xy_max, xy_max)
        plt.ylim(-xy_max, xy_max)
        plt.show()


if __name__ == '__main__':
    # Quick visual sanity check on the samples from the first four file pairs.
    data, labels = read_data(0, 4)
    print(data.shape)
    print(labels.shape)
    view_data(data, labels)