from pygcn import load_data
import numpy as np


# Randomly split the edge list of one dataset into train/valid/test files.
#
# Reads  data/<dataset>/<dataset>.cites  (whitespace-separated integers, one
# edge per row) and writes, next to it:
#   <dataset>.cites                   -- the full edge list, re-saved with '%d'
#   <dataset>_train_<train_pct>.cites
#   <dataset>_valid_<train_pct>.cites
#   <dataset>_test_<train_pct>.cites
# The split is drawn from NumPy's global RNG, which is not seeded here, so
# each run produces a different partition.
dataset = 'protein_vidal'

# Fractions of the edges assigned to each split; train gets the remainder.
test_pct = 0.5
valid_pct = 0.1
train_pct = 1 - test_pct - valid_pct

edges_path = 'data/{}/{}.cites'.format(dataset, dataset)
edges = np.loadtxt(edges_path, dtype=int)

m = edges.shape[0]
# Random ordering of the row indices 0..m-1.
order = np.random.permutation(m)

# Split sizes are truncated to integers; the test split takes whatever rows
# remain, so the three splits always cover every row exactly once.
n_train = int(m * train_pct)
n_valid = int(m * valid_pct)
edges_train = edges[order[:n_train]]
edges_valid = edges[order[n_train:n_train + n_valid]]
edges_test = edges[order[n_train + n_valid:]]

# Re-save the full edge list. The data array and fmt must be arguments of
# np.savetxt, not of str.format -- otherwise savetxt is called without its
# required data argument and raises TypeError before any split is written.
np.savetxt(edges_path, edges, fmt='%d')

# NOTE(review): all three split files are tagged with the *train* fraction;
# this looks like a deliberate identifier for the split -- confirm with the
# code that loads these files before changing it.
np.savetxt('data/{}/{}_train_{:.2f}.cites'.format(dataset, dataset, round(train_pct, 2)), edges_train, fmt='%d')
np.savetxt('data/{}/{}_valid_{:.2f}.cites'.format(dataset, dataset, round(train_pct, 2)), edges_valid, fmt='%d')
np.savetxt('data/{}/{}_test_{:.2f}.cites'.format(dataset, dataset, round(train_pct, 2)), edges_test, fmt='%d')

#edges = np.loadtxt('data/citeseer/citeseer_old.cites', dtype=str)
#edges_num = np.zeros(edges.shape)
#nodes = list(np.unique(edges))
#for i in range(edges.shape[0]):
#    edges_num[i, 0] = nodes.index(edges[i, 0])
#    edges_num[i, 1] = nodes.index(edges[i, 1])
#
#features = np.loadtxt('data/citeseer/citeseer_old.content', dtype=str)
#for i in range(features.shape[0]):
#    features[i, 0] = str(nodes.index(features[i, 0]))
#np.savetxt('data/citeseer/citeseer.content', features, delimiter='\t', fmt = '%s')
#nodes_with_features = [int(x) for x in np.unique(features[:, 0])]
#
#edges_data = np.zeros(edges_num.shape)
#cur_row = 0
#for i in range(edges_num.shape[0]):
#    if edges_num[i, 0] in nodes_with_features and edges_num[i, 1] in nodes_with_features:
#        edges_data[cur_row] = edges_num[i]
#        cur_row += 1
#edges_data = edges_data[:cur_row]
#np.savetxt('data/citeseer/citeseer.cites', edges_data, delimiter='\t', fmt = '%d')
#
#
#idx_features_labels = np.genfromtxt('data/citeseer/citeseer.content', dtype=np.dtype(str))
#idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
#idx_map = {j: i for i, j in enumerate(idx)}
#edges_unordered = np.genfromtxt("data/citeseer/citeseer.cites",
#                                    dtype=np.int32)
#edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
#                     dtype=np.int32).reshape(edges_unordered.shape)


