import numpy as np

def get_my_neighbors(root='./dataset', name='Cora'):
    """Build every node's neighbor list from a Planetoid graph and save it.

    Each node's list starts with the node itself, followed by the target of
    every edge whose source is that node, in ``edge_index`` order. The
    resulting ``{node_id: [node_id, neighbor, ...]}`` dict is written with
    ``np.save`` to ``node_neighbors.npy`` in the current working directory.
    Because it is a dict, numpy stores it as a pickled object array; read it
    back with ``np.load('node_neighbors.npy', allow_pickle=True).item()``.

    Args:
        root: directory where the Planetoid dataset is stored / downloaded.
        name: Planetoid dataset name. Defaults to ``'Cora'``, the value that
            was previously hard-coded.
    """
    from torch_geometric.datasets import Planetoid
    dataset = Planetoid(root=root, name=name)
    # ``dataset[0]`` is the public accessor for the single graph; direct
    # ``dataset.data`` access is discouraged in recent PyG releases.
    data = dataset[0]
    edge_index = data.edge_index.numpy()
    # Seed every node's list with itself so isolated nodes still get an entry.
    node_neighbors = {i: [i] for i in range(data.x.shape[0])}
    for src, dst in zip(edge_index[0], edge_index[1]):
        node_neighbors[src].append(dst)
    np.save('node_neighbors.npy', node_neighbors)


def get_my_feats(root='./dataset', name='Cora'):
    """Save the node feature matrix of a Planetoid dataset to ``feats.npy``.

    The file is written in the current working directory with ``np.save``.

    Args:
        root: directory where the Planetoid dataset is stored / downloaded.
        name: Planetoid dataset name. Defaults to ``'Cora'``, the value that
            was previously hard-coded.
    """
    from torch_geometric.datasets import Planetoid
    # ``dataset[0]`` is the public accessor for the single graph; direct
    # ``dataset.data`` access is discouraged in recent PyG releases.
    data = Planetoid(root=root, name=name)[0]
    feats = data.x.numpy()
    np.save('feats.npy', feats)
def get_my_labels(root='./dataset', name='Cora'):
    """Save the node label vector of a Planetoid dataset to ``labels.npy``.

    The file is written in the current working directory with ``np.save``.

    Args:
        root: directory where the Planetoid dataset is stored / downloaded.
        name: Planetoid dataset name. Defaults to ``'Cora'``, the value that
            was previously hard-coded.
    """
    from torch_geometric.datasets import Planetoid
    # ``dataset[0]`` is the public accessor for the single graph; direct
    # ``dataset.data`` access is discouraged in recent PyG releases.
    data = Planetoid(root=root, name=name)[0]
    labels = data.y.numpy()
    np.save('labels.npy', labels)


def dataset_splits(data, num_classes, train_rate=0.6, val_rate=0.2, generator=None):
    """Randomly split node indices into per-class train / val / test sets.

    For each class ``c`` in ``range(num_classes)``, the node indices with
    ``data.y == c`` are shuffled. The first ``round(train_rate * n_c)`` go to
    train, the next ``round(val_rate * n_c)`` go to validation, and the rest
    go to test.

    Args:
        data: object exposing a label tensor ``data.y`` (assumed 1-D, e.g. a
            PyG ``Data`` object).
        num_classes: number of classes. Nodes whose label is outside
            ``range(num_classes)`` are not placed in any split.
        train_rate: fraction of each class placed in the training split.
        val_rate: fraction of each class placed in the validation split.
        generator: optional ``torch.Generator`` used for shuffling, so a split
            can be reproduced. ``None`` uses the global RNG, which was the
            previous behavior.

    Returns:
        ``(train_index, val_index, test_index)``: 1-D index tensors. Each is
        the concatenation of its per-class pieces in class order.
    """
    # Local import: this function is defined above the module-level
    # ``import torch`` and should not depend on that ordering.
    import torch

    train_parts, val_parts, test_parts = [], [], []
    for c in range(num_classes):
        index = (data.y == c).nonzero().view(-1)
        index = index[torch.randperm(index.size(0), generator=generator)]
        n_train = int(round(train_rate * len(index)))
        n_val = int(round(val_rate * len(index)))
        train_parts.append(index[:n_train])
        val_parts.append(index[n_train:n_train + n_val])
        # Whatever is left after train and val (slicing clamps) is test.
        test_parts.append(index[n_train + n_val:])

    train_index = torch.cat(train_parts, dim=0)
    val_index = torch.cat(val_parts, dim=0)
    test_index = torch.cat(test_parts, dim=0)
    return train_index, val_index, test_index


def get_t_v_t(train_index, val_index, test_index):
    """Write the train / val / test node indices to text files, one per line.

    Creates (or overwrites) ``train_set.txt``, ``val_set.txt`` and
    ``test_set.txt`` in the current working directory.

    Args:
        train_index: 1-D sequence of integer indices (e.g. a tensor returned
            by ``dataset_splits``) written to ``train_set.txt``.
        val_index: same, written to ``val_set.txt``.
        test_index: same, written to ``test_set.txt``.
    """
    outputs = (
        ('train_set.txt', train_index),
        ('val_set.txt', val_index),
        ('test_set.txt', test_index),
    )
    for path, indices in outputs:
        # ``with`` guarantees each file is flushed and closed; the previous
        # version left all three handles open.
        with open(path, 'w', encoding='utf-8') as f:
            # ``int()`` also works for non-CPU tensors, unlike ``.numpy()``.
            f.write(''.join(f'{int(i)}\n' for i in indices))

# Script body: build a per-class 60/20/20 train/val/test node split and dump
# it, plus neighbor lists, node features and labels, to files in the CWD.
# NOTE(review): this runs at import time (there is no
# ``if __name__ == "__main__"`` guard).
from torch_geometric.datasets import Amazon
import torch  # needed at call time by dataset_splits, which is defined above
dataset = Amazon(root='./dataset',name='Photo')
# NOTE(review): newer PyG versions warn on direct ``dataset.data`` access;
# ``dataset[0]`` is the usual accessor for a single-graph dataset.
data=dataset.data
num_classes=dataset.num_classes
train_index,val_index,test_index = dataset_splits(data, num_classes,train_rate=0.6, val_rate=0.2)

# Writes train_set.txt / val_set.txt / test_set.txt.
get_t_v_t(train_index,val_index,test_index)
# NOTE(review): the split above is over Amazon 'Photo' nodes, but the three
# calls below load Planetoid 'Cora' (see their definitions). The saved split
# files and the neighbor/feature/label files may therefore describe
# different graphs; confirm which dataset is intended.
get_my_neighbors()
get_my_feats()
get_my_labels()


