import os

import torch
from torch.utils.data import *
from matplotlib import pyplot as plt



def read_data(start=0, end=2048, root_path='data/distance_constraint_modified_square/', verbose=True):
    """Load paired input/label trajectories from per-sample text files.

    For every sample index ``i`` in ``[start, end)`` two files are read from
    ``root_path``:

    * ``data_distance_constraint_modified_square_{i}_ori.txt`` -> ``data``
    * ``data_distance_constraint_modified_square_{i}.txt``     -> ``label``

    Each file must provide at least ``TS * P_N * DIM`` (20 * 4 * 2 = 160)
    lines, one float per line. Only the first 160 are used. They are laid out
    row-major as (time step, point, coordinate), so file ``i`` fills rows
    ``(i - start) * 20 .. (i - start) * 20 + 19`` of the result.

    Args:
        start: first sample index (inclusive).
        end: last sample index (exclusive).
        root_path: directory containing the sample files.
        verbose: if True, print the mean (over rows) of the summed squared
            difference between data and label, then the number of rows.

    Returns:
        ``(data, label)``: tensors of torch's default float dtype, each with
        shape ``[(end - start) * 20, 4, 2]``.

    Raises:
        ValueError: if ``end < start``, or a file has a missing or non-numeric
            line among its first 160.
        OSError: if a file cannot be opened.
    """
    TS = 20    # time steps per sample file
    P_N = 4    # points per time step
    DIM = 2    # coordinates per point
    per_file = TS * P_N * DIM

    N = end - start
    if N < 0:
        raise ValueError(f"end ({end}) must not be smaller than start ({start})")

    def _read_block(path):
        # Read the first `per_file` lines as floats and reshape them to
        # (TS, P_N, DIM). A missing line reads as '' and makes float() raise
        # ValueError, exactly like a malformed one. `with` guarantees the file
        # is closed even when parsing fails.
        with open(path, "r") as f:
            values = [float(f.readline()) for _ in range(per_file)]
        return torch.tensor(values).view(TS, P_N, DIM)

    # Pre-allocate with the default dtype; every row is overwritten below.
    data = torch.ones([N * TS, P_N, DIM])
    label = torch.ones([N * TS, P_N, DIM])
    for i in range(N):
        sample = start + i
        rows = slice(i * TS, (i + 1) * TS)
        data[rows] = _read_block(
            os.path.join(root_path, f'data_distance_constraint_modified_square_{sample}_ori.txt'))
        label[rows] = _read_block(
            os.path.join(root_path, f'data_distance_constraint_modified_square_{sample}.txt'))

    if verbose:
        if N > 0:
            # Mean over rows of the per-row summed squared difference.
            # Skipped for an empty range to avoid dividing by zero.
            print(((data - label) ** 2).sum() / data.shape[0])
        print(len(data))
    return data, label
               
    

def get_data_loader(device='cuda'):
    """Load the full dataset and split it into train/validation datasets.

    Despite the name, this returns ``TensorDataset`` objects, not
    ``DataLoader`` objects; wrap them in ``DataLoader`` to batch or shuffle.

    The first ``num // 10 * 8`` rows (about 80%) become the training set and
    the remaining rows the validation set. Rows keep the order in which
    ``read_data()`` returns them, with no shuffling.

    Args:
        device: device the tensors are moved to. The default ``'cuda'`` is
            equivalent to calling ``.cuda()``. Pass ``'cpu'`` on machines
            without a GPU.

    Returns:
        ``(train_ds, val_ds)``: ``TensorDataset`` instances of
        ``(data, label)`` pairs.
    """
    data, label = read_data()
    num = data.shape[0]
    # NOTE(review): read_data writes 20 consecutive rows per source file, and
    # this split index is not necessarily a multiple of 20, so one file's rows
    # may straddle train/val. Confirm that this is acceptable.
    idx = num // 10 * 8

    d_tr = data[:idx].to(device)
    d_val = data[idx:].to(device)
    print(d_tr.shape); print(d_val.shape)
    l_tr = label[:idx].to(device)
    l_val = label[idx:].to(device)

    train_ds = TensorDataset(d_tr, l_tr)
    val_ds = TensorDataset(d_val, l_val)
    return train_ds, val_ds

def plot(d, l):
    """Scatter the first point of every row (inputs red, labels blue) and show the figure.

    ``d`` and ``l`` are indexed as ``[row, point, coordinate]``. Only point 0
    is drawn, with coordinate 0 on the x-axis and coordinate 1 on the y-axis.
    """
    for points, colour in ((d, 'r'), (l, 'b')):
        first_point = points[:, 0]
        plt.scatter(first_point[:, 0], first_point[:, 1], c=colour)
    plt.show()
    

if __name__ == '__main__':
    # Smoke check: load the default sample range and report both tensor shapes.
    inputs, targets = read_data()
    for arr in (inputs, targets):
        print(arr.shape)