"""
Compute the node embeddings
with continuous WL
"""

import numpy as np


from sklearn.preprocessing import scale



import argparse
import igraph as ig
import os

# Command-line interface: all three options are mandatory.
parser = argparse.ArgumentParser(description="Computing kernel")
parser.add_argument(
	'--data_dir',
	help="Path to the data directory",
	required=True,
)
parser.add_argument(
	'--data_name',
	help="Dataset Name",
	required=True,
)
parser.add_argument(
	'--h',
	help="Number of WL iterations",
	type=int,
	required=True,
)

args = parser.parse_args()

# Expose the parsed options under the module-level names used further down.
data_dir, data_name, h = args.data_dir, args.data_name, args.h


def create_adj_avg(adj_cur):
	"""
	Scale an adjacency matrix by inverse node degrees and transpose it.

	The row sums of `adj_cur` are used as degrees. Every row sum that is
	not exactly 1 is decremented by one before being inverted; with the
	caller in this file passing (adjacency + identity), that removes the
	self-loop from the count while rows that sum to 1 keep a divisor of 1.

	Entry (i, j) of the result equals adj_cur[j, i] / degree[i]; for a
	symmetric input this is the row-normalised matrix.

	adj_cur: square matrix supporting `np.sum(..., axis=1)` and `.dot`.
	Returns the scaled and transposed matrix.
	"""
	row_sums = np.asarray(np.sum(adj_cur, axis=1)).reshape(-1)

	# Rows summing to exactly 1 are left as they are; all others lose one unit.
	row_sums[row_sums != 1] -= 1

	inv_deg_diag = np.diag(1 / row_sums)

	# (A . D)^T: column j of A is scaled by 1 / degree[j], then transposed.
	return adj_cur.dot(inv_deg_diag.T).T


def create_labels_seq_cont(node_features, adj_mat, h):
	"""
	Build the continuous WL embedding of every node in every graph.

	Starting from the input features x_0, each iteration computes

		x_{t+1} = 0.5 * (A_avg . x_t + x_t)

	where A_avg is the output of `create_adj_avg` applied to
	(adjacency + identity) with its diagonal reset to zero. For a symmetric
	adjacency matrix this averages a node's own features with the mean of
	its neighbours' features.

	node_features: sequence holding one 2-D array (nodes x features) per graph.
	adj_mat: sequence holding the matching square adjacency matrix per graph;
		it is only accessed when h > 0.
	h: number of propagation iterations (must be >= 0).

	Returns a list with one array per graph in which the features of
	iterations 0..h are concatenated column-wise.

	Raises ValueError if h is negative.
	"""
	if h < 0:
		raise ValueError("h must be a non-negative number of iterations")

	n_graphs = len(node_features)
	labels_sequence = []

	for i in range(n_graphs):
		# Iteration 0 is the unmodified input.
		graph_feat = [node_features[i]]

		if h > 0:
			# The averaging operator depends only on the graph, so it is
			# built once here rather than once per iteration.
			adj_avg = create_adj_avg(adj_mat[i] + np.identity(adj_mat[i].shape[0]))
			# Remove the self-loop weight: the node's own features enter
			# through the explicit "+ prev" term below.
			np.fill_diagonal(adj_avg, 0)

			for _ in range(h):
				prev = graph_feat[-1]
				graph_feat.append(0.5 * (np.dot(adj_avg, prev) + prev))

		labels_sequence.append(np.concatenate(graph_feat, axis=1))
		if i % 100 == 0:
			print("Processed %d graphs out of %d" % (i, n_graphs))

	return labels_sequence


def read_gml(file_name):
	"""
	Load one graph file and return its node labels and adjacency matrix.

	file_name: path passed to `igraph.read`.

	Returns (node_features, adj_mat):
		node_features -- the values of the vertices' 'label' attribute; when
			the graph has no such attribute, each vertex's degree as a string.
		adj_mat -- numpy array built from the graph's adjacency matrix data.
	"""
	g = ig.read(file_name)

	if 'label' in g.vs.attribute_names():
		node_features = g.vs['label']
	else:
		# Unlabelled graph: fall back to the vertex degree, converted to str.
		node_features = [str(degree) for degree in g.vs.degree()]

	adj_mat = np.asarray(g.get_adjacency().data)

	return node_features, adj_mat


# Collect the dataset's GML files in a deterministic order.
# NOTE(review): the sort is lexicographic ("g10" < "g2"); presumably the rows
# of the feature file loaded below follow the same order -- confirm.
files = os.listdir("%s/%s" % (data_dir, data_name))

graphs = [g for g in files if g.endswith('gml')]
graphs.sort()


# Only the adjacency matrices are taken from the GML files. The 'label'
# values returned by read_gml are not used: node features are loaded from
# the .npy file below, so they are no longer converted to float here (that
# conversion could fail on non-numeric labels without affecting the result).
adj_mat = []

for graph_file in graphs:
	file_name = "%s/%s/%s" % (data_dir, data_name, graph_file)
	_, adj_mat_cur = read_gml(file_name)
	adj_mat.append(adj_mat_cur.astype(int))

# Alias kept from the original script; it is not read anywhere in this section.
am = adj_mat

# Node features, one array per graph. Presumably stored as an object array
# when graphs differ in size; NumPy >= 1.16.3 only loads such files when
# pickling is explicitly allowed.
# NOTE: unpickling can execute arbitrary code -- only load feature files
# from a trusted source.
node_features = np.load('./data/%s_node_feat.npy' % (data_name), allow_pickle=True)

# Number of nodes (rows) of every graph's feature array.
n_nodes = np.asarray([graph_feat.shape[0] for graph_feat in node_features])

# Standardise each feature column over the nodes of all graphs together,
# then cut the stacked matrix back into one block per graph.
node_features_data = scale(np.concatenate(node_features, axis=0), axis=0)
splits_idx = np.cumsum(n_nodes).astype(int)
# Splitting at every cumulative boundary except the last yields exactly one
# block per graph (no trailing empty block).
node_features = np.vsplit(node_features_data, splits_idx[:-1])

labels_sequence = create_labels_seq_cont(node_features, adj_mat, h)

# Kept from the original script; it is not read anywhere in this section.
node_aggr_type = 'avg'


out_dir = "./output/embeddings"

if not os.path.exists(out_dir):
	os.makedirs(out_dir)

out_name = "%s_wl_cont_embed_h%d" % (data_name, h)

# Graphs with different node counts give embeddings of different shapes.
# Recent NumPy versions refuse to build an array implicitly from such a
# ragged list, so pack the embeddings into an object array explicitly.
# Equally-shaped embeddings are still saved as a regular array, as before.
if len({emb.shape for emb in labels_sequence}) > 1:
	embeddings = np.empty(len(labels_sequence), dtype=object)
	for idx, emb in enumerate(labels_sequence):
		embeddings[idx] = emb
else:
	embeddings = np.asarray(labels_sequence)

np.save("%s/%s" % (out_dir, out_name), embeddings)

