from torch_scatter import scatter
import torch
import numpy as np
# Scratch / sanity-check statements from experimentation. Neither `out` nor
# `j` appears to be used by the rest of this script.
out = torch.tensor([1, 2])
for j in range(0, 3, 1):
    print(j)
    
from torch_geometric.nn import GCNConv

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import community as community_louvain
import torch
from torch_geometric.data import InMemoryDataset, Data
from scipy import sparse
from scipy.sparse import coo_matrix
# Build a small toy graph. NetworkX creates nodes in first-appearance order
# of the edge list, which here gives nodes 0, 1, 2, 3.
G = nx.Graph()
edgelist = [(0, 1), (0, 2), (1, 3), (3, 0)]
G.add_edges_from(edgelist)
# Node features: row i is used as the 2-d feature vector of node i.
x = torch.tensor(
    [
        [0.26215712543900227, 0.9085666240799217],
        [0.5628533963982644, 0.26962294593726166],
        [0.3482586437373957, 0.6794686451031617],
        [0.221873877966124, 0.15592168410567253],
    ],
    dtype=torch.float,
)

# `to_scipy_sparse_matrix` was removed in NetworkX 3.0; use the array variant
# when it exists and fall back to the old name only on older releases.
_to_scipy_sparse = (
    getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
)
# Pin the node order so matrix index i is node i -- the same convention `y`
# relies on below (`partition[i]` for i in range(num_nodes)).
adj = _to_scipy_sparse(G, nodelist=sorted(G.nodes())).tocoo()
row = torch.from_numpy(adj.row.astype(np.int64))
col = torch.from_numpy(adj.col.astype(np.int64))
edge_index = torch.stack([row, col], dim=0)  # shape [2, num_edges]


# Compute communities. NOTE(review): python-louvain's best_partition is
# randomized unless `random_state` is passed, so labels may differ per run.
partition = community_louvain.best_partition(G)
y = torch.tensor([partition[i] for i in range(G.number_of_nodes())])

# Select a single training node for each community (the first one found).
train_mask = torch.zeros(y.size(0), dtype=torch.bool)
for i in range(int(y.max()) + 1):
    members = (y == i).nonzero(as_tuple=False).view(-1)
    if members.numel() > 0:  # tolerate label ids that have no members
        train_mask[members[0]] = True

data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)


# Sparse COO adjacency with value 1.0 for every edge in `edge_index`.
A = data.edge_index
value = torch.ones(data.edge_index.shape[1], dtype=torch.float32)
A = torch.sparse_coo_tensor(
    indices=A, values=value, size=[data.x.shape[0], data.x.shape[0]]
)
def compute_ppr(a, self_loop=True):
    """Return the dense matrix ``A @ A - A - I`` for a square adjacency ``a``.

    Despite the name, this is not a personalized-PageRank diffusion. For a
    0/1 adjacency matrix, entry (i, j) of ``A @ A`` counts length-2 walks
    from i to j. Subtracting ``A`` and the identity removes one count per
    direct edge and one from each diagonal entry. Presumably this is meant to
    score two-hop neighbours (TODO confirm intent).

    Args:
        a: square adjacency matrix, either a sparse COO tensor or a dense
            tensor.
        self_loop: currently ignored; kept so existing callers keep working.

    Returns:
        Dense tensor with the same shape, dtype and device as ``a``.
    """
    if a.is_sparse:
        a = a.to_dense()
    n = a.shape[0]
    # Build the identity with a's dtype/device so the subtraction neither
    # promotes integer inputs to float nor fails for non-CPU tensors.
    eye = torch.eye(n, dtype=a.dtype, device=a.device)
    return a @ a - a - eye

# Turn the A @ A - A - I scores from compute_ppr into a 0/1 matrix without
# diagonal entries, then convert its nonzeros into an edge_index.
A=compute_ppr(A)
zero = torch.zeros_like(A)
one = torch.ones_like(A)
# Clamp negative scores to 0 ...
new_A1=torch.where(A < 0,zero,A)
# ... and cap scores above 1 at 1, so every entry is now in [0, 1].
new_A=torch.where(new_A1 > 1,one,new_A1)
# Subtracting the identity makes every diagonal entry <= 0; the clamp on the
# next line then zeroes them, i.e. this removes the diagonal.
new_A=new_A-torch.eye(new_A.shape[0])
new_A=torch.where(new_A < 0,zero,new_A)
# NOTE(review): relies on scipy implicitly converting a (CPU, no-grad) torch
# tensor to a NumPy array; `new_A.numpy()` would make this explicit.
adj=coo_matrix(new_A)
row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
# NOTE(review): this edge_index is only bound to a module-level name and is
# never stored on `data`; code that reads `data.edge_index` later will not
# see it -- confirm whether that is intended.
edge_index = torch.stack([row, col], dim=0)

# [0, 1, 2, 1, 3, 2, 3]
# [0, 1, 1, 2, 2, 3, 3]






# Every node that is not a training node becomes a test node. The shuffle
# does not change which nodes are selected, but it does draw from torch's
# global RNG.
remaining = torch.nonzero(~data.train_mask).flatten()
remaining = remaining[torch.randperm(remaining.numel())]
data.test_mask = torch.zeros_like(data.train_mask)
data.test_mask[remaining] = True



# Augment the graph with two-hop edges: for every pair of edges (a, b) and
# (b, c) with a != c, add the edge (a, c); then drop duplicate edges.
edge_index = data.edge_index
edge_index_i, edge_index_j = edge_index
# match[e, k] is True when edge e ends where edge k starts (dst[e] == src[k])
# and the path does not return to its start (src[e] != dst[k]). This replaces
# an O(E^2) Python double loop with one vectorised comparison; it still
# allocates an E x E boolean matrix (E = number of edges).
match = (edge_index_j.view(-1, 1) == edge_index_i.view(1, -1)) & (
    edge_index_i.view(-1, 1) != edge_index_j.view(1, -1)
)
first, second = match.nonzero(as_tuple=True)
# Shape [2, num_two_hop]. It is an empty [2, 0] tensor (not None) when there
# are no two-hop paths, so the concatenation below never fails.
second_neighbour = torch.stack(
    (edge_index_i[first], edge_index_j[second]), dim=0
)

edge_index = torch.cat((edge_index, second_neighbour), dim=1)
# Deduplicate *columns* (edges). axis=0 would only deduplicate/sort the two
# rows (sources vs. targets) and could even swap them.
uniques = np.unique(edge_index.numpy(), axis=1)

# np.unique(axis=1) can return a non-contiguous view; copy it to contiguous
# memory before handing it to torch.
edge_index = torch.from_numpy(np.ascontiguousarray(uniques))
data.edge_index = edge_index
nx.draw_networkx(G)
plt.show()


                    









import torch.nn.functional as F
from gcn_conv_copy import GCNConv

class Net(torch.nn.Module):
    """Two stacked ``GCNConv`` layers followed by ``log_softmax``.

    The activation and dropout between the layers are currently commented
    out. By default the model takes its input width and its inputs from the
    module-level ``data`` object, so ``Net()`` and ``model()`` keep working.
    Pass explicit arguments to use it with another graph.
    """

    def __init__(self, in_channels=None, hidden_channels=2, out_channels=2):
        super().__init__()
        if in_channels is None:
            in_channels = data.num_node_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x=None, edge_index=None):
        """Return per-node log-probabilities (``log_softmax`` over dim 1).

        ``x`` and ``edge_index`` default to ``data.x`` and
        ``data.edge_index``, looked up at call time.
        """
        if x is None:
            x = data.x
        if edge_index is None:
            edge_index = data.edge_index

        x = self.conv1(x, edge_index)
        # NOTE: non-linearity and dropout are disabled; re-enable these two
        # lines to get a conventional two-layer GCN.
        # x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

    
device = torch.device('cpu')
model = Net().to(device)
data = data.to(device)
# Two parameter groups: weight decay is applied only to the first
# convolution's parameters.
optimizer = torch.optim.Adam(
    [
        {"params": model.conv1.parameters(), "weight_decay": 5e-4},
        {"params": model.conv2.parameters(), "weight_decay": 0},
    ],
    lr=0.01,
)

def train():
    """Run one optimisation step on the training nodes.

    Returns:
        The training loss of this step as a Python float. Callers that
        ignore the return value are unaffected.
    """
    # test() switches the model to eval mode; switch back before each step
    # so dropout and similar layers behave correctly during training.
    model.train()
    optimizer.zero_grad()
    out = model()
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
    
@torch.no_grad()
def test():
    """Return ``[train_accuracy, test_accuracy]`` for the current model.

    Accuracy is the fraction of nodes in each mask whose arg-max prediction
    equals ``data.y``. An empty mask yields ``float('nan')`` instead of
    raising ``ZeroDivisionError``. Runs under ``torch.no_grad()`` so no
    autograd graph is built.
    """
    model.eval()
    pred = model().argmax(dim=1)
    accs = []
    # Explicit order so the results line up with "Train, Test" in the log.
    for mask in (data.train_mask, data.test_mask):
        total = int(mask.sum())
        correct = int((pred[mask] == data.y[mask]).sum())
        accs.append(correct / total if total else float("nan"))
    return accs

# Train for 10 epochs, reporting train/test accuracy after each one.
log = 'Epoch: {:03d}, Train: {:.4f}, Test: {:.4f}'
for epoch in range(1, 11):
    train()
    accs = test()
    print(log.format(epoch, *accs))