import os

import networkx as nx
import numpy as np
import torch

from pygcn import load_data

dataset = 'fb_small'
adj_test, features, labels, idx_train, idx_val, idx_test = load_data('data/{}/'.format(dataset), dataset)

# Build an undirected networkx graph from the loaded adjacency.
# NOTE(review): assumes adj_test is a torch (sparse) tensor, since it is densified via .to_dense().
g = nx.from_numpy_array(adj_test.to_dense().detach().numpy())

# Keep the first connected component networkx yields. In practice this is the
# component containing node 0; it is NOT necessarily the largest component.
c = next(nx.connected_components(g), None)
if c is None:
    raise ValueError('graph for dataset {!r} has no nodes'.format(dataset))
subg = g.subgraph(c)
nodes_include = sorted(c)
# Rows of subfeats/sublabels follow the sorted node order in nodes_include.
subfeats = features[nodes_include, :]
sublabels = labels[nodes_include].unsqueeze(1)
print(len(subg))

# Pass nodelist explicitly so that row/column i of the adjacency matrix is
# nodes_include[i], i.e. the same order used for subfeats/sublabels. Without it
# the matrix follows the subgraph view's node iteration order, which is not
# guaranteed to be sorted and would silently misalign edges with features.
edges = torch.tensor(nx.to_numpy_array(subg, nodelist=nodes_include)).nonzero()

# Re-indexed node ids 0..len(subg)-1, one per row of subfeats/sublabels.
idx = torch.arange(len(subg)).unsqueeze(1)
# Rows of [new_node_id, features..., label], as a numpy array (float dtype).
idx_feat_lab = torch.cat((idx.float(), subfeats, sublabels.float()), dim=1).numpy()

    
test_pct = 0.5
valid_pct = 0.1
# Training gets whatever fraction remains after test and validation.
train_pct = 1 - test_pct - valid_pct

#edges = np.loadtxt('data/cora/cora.cites', dtype=int)

# Shuffle the edge rows and cut them into contiguous train/valid/test slices.
# NOTE(review): no RNG seed is set, so the split differs on every run; call
# np.random.seed(...) first if reproducible splits are needed.
# NOTE(review): `edges` is the .nonzero() of an undirected graph's (symmetric)
# adjacency matrix, so each undirected edge appears as both (i, j) and (j, i),
# and the two directions are assigned to splits independently. For link
# prediction this can leak test edges into training -- confirm it is intended.
# Any self-loops in the loaded adjacency would also appear as (i, i) rows.
m = edges.shape[0]
order = np.random.permutation(m)
n_train = int(m * train_pct)
n_valid = int(m * valid_pct)
edges_train = edges[order[:n_train]]
edges_valid = edges[order[n_train:n_train + n_valid]]
edges_test = edges[order[n_train + n_valid:]]

out_dir = 'data/{}_connected'.format(dataset)
# np.savetxt does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)
# Every split file (train, valid and test alike) is tagged with the TRAIN
# fraction, so files from the same split share one suffix.
tag = '{:.2f}'.format(round(train_pct, 2))
np.savetxt('{}/{}_connected.cites'.format(out_dir, dataset), edges, fmt='%d')
for split_name, split_edges in (('train', edges_train), ('valid', edges_valid), ('test', edges_test)):
    np.savetxt('{}/{}_connected_{}_{}.cites'.format(out_dir, dataset, split_name, tag), split_edges, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_train_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_test_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_valid_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')

#edges = np.loadtxt('data/citeseer/citeseer_old.cites', dtype=str)
#edges_num = np.zeros(edges.shape)
#nodes = list(np.unique(edges))
#for i in range(edges.shape[0]):
#    edges_num[i, 0] = nodes.index(edges[i, 0])
#    edges_num[i, 1] = nodes.index(edges[i, 1])
#
#features = np.loadtxt('data/citeseer/citeseer_old.content', dtype=str)
#for i in range(features.shape[0]):
#    features[i, 0] = str(nodes.index(features[i, 0]))
#np.savetxt('data/citeseer/citeseer.content', features, delimiter='\t', fmt = '%s')
#nodes_with_features = [int(x) for x in np.unique(features[:, 0])]
#
#edges_data = np.zeros(edges_num.shape)
#cur_row = 0
#for i in range(edges_num.shape[0]):
#    if edges_num[i, 0] in nodes_with_features and edges_num[i, 1] in nodes_with_features:
#        edges_data[cur_row] = edges_num[i]
#        cur_row += 1
#edges_data = edges_data[:cur_row]
#np.savetxt('data/citeseer/citeseer.cites', edges_data, delimiter='\t', fmt = '%d')
#
#
#idx_features_labels = np.genfromtxt('data/citeseer/citeseer.content', dtype=np.dtype(str))
#idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
#idx_map = {j: i for i, j in enumerate(idx)}
#edges_unordered = np.genfromtxt("data/citeseer/citeseer.cites",
#                                    dtype=np.int32)
#edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
#                     dtype=np.int32).reshape(edges_unordered.shape)


