def get_feature_vector(a):
    '''
    Encode one agent record as a flat list of integers (21 entries).

    a: mapping with the keys 'region', 'ethnicity', 'age', 'gender' and
       'status', each holding one of the labels listed below.

    Layout of the result, in order:
      - 13-entry one-hot vector for 'region'
      - 5-entry one-hot vector for 'ethnicity'
      - one ordinal code for 'age'    (0 = '18-24' ... 6 = '65+')
      - one ordinal code for 'gender' (0 = 'female', 1 = 'male')
      - one ordinal code for 'status' (0 = 'normal', 1 = 'overweight',
        2 = 'obese')

    A missing key or an unknown label raises KeyError.
    '''
    regions = ('northwest_antelope_valley', 'northeast_antelope_valley',
               'lancaster', 'quartz_hill', 'lake_los_angeles',
               'northwest_palmdale', 'leona_valley', 'palmdale',
               'desert_view_highlands', 'sun_village', 'littlerock',
               'acton', 'southeast_antelope_valley')
    ethnicities = ('latino', 'white', 'black', 'asian', 'other')
    ages = ('18-24', '25-29', '30-39', '40-49', '50-59', '60-64', '65+')
    genders = ('female', 'male')
    statuses = ('normal', 'overweight', 'obese')

    def index_of(labels, label):
        # Dict lookup (not tuple.index) so an unknown label raises KeyError.
        return {name: pos for pos, name in enumerate(labels)}[label]

    def one_hot(labels, label):
        hot = index_of(labels, label)
        return [1 if pos == hot else 0 for pos in range(len(labels))]

    def ordinal(labels, label):
        return [index_of(labels, label)]

    # Features are encoded in this fixed order; downstream code relies on
    # the region block coming first and 'status' being the last entry.
    encoders = (
        ('region', regions, one_hot),
        ('ethnicity', ethnicities, one_hot),
        ('age', ages, ordinal),
        ('gender', genders, ordinal),
        ('status', statuses, ordinal),
    )
    vector = []
    for key, labels, encode in encoders:
        vector += encode(labels, a[key])
    return vector


def spatial_pa(X, r, m, p_recip, scale_cols=13, scale=5):
    '''
    Generate a random directed spatial preferential-attachment network.

    Nodes arrive one at a time in the row order of X. Each arriving node i
    makes m independent draws (with replacement, so it gains at most m
    distinct out-neighbours) among nodes 0..i-1, choosing node j with
    probability proportional to exp(-dist(i, j) / r) * degree(j), where
    dist is the Euclidean distance between (scaled) feature rows and degree
    is the total in+out degree in the graph built so far. If every such
    weight is zero (always the case for i == 1, because node 0 has no edges
    yet), the distance kernel alone is used, and if that also underflows to
    zero the choice is uniform.

    X: array-like of shape (n, d); row k holds the features of node k.
       A float copy is made, so the caller's data is not modified.
    r: rate of decay of connection probability with distance (must be > 0);
       a larger r makes distance matter less.
    m: number of attachment draws each arriving node makes.
    p_recip: probability that each drawn edge i -> j is reciprocated with an
       edge j -> i.
    scale_cols, scale: the first `scale_cols` columns of X are multiplied by
       `scale` before distances are computed. The defaults (13, 5) match the
       13-column region one-hot block produced by get_feature_vector.

    Returns a networkx.DiGraph whose nodes are 0..n-1.
    '''
    import networkx as nx
    import numpy as np
    from scipy.spatial.distance import squareform, pdist
    import random

    g = nx.DiGraph()
    g.add_node(0)
    # Copy as float so the in-place scaling neither mutates the caller's
    # array nor fails for a non-integer `scale` on integer input.
    X = np.array(X, dtype=float)
    # Up-weight the leading columns so differences there dominate distance.
    X[:, :scale_cols] *= scale
    n = X.shape[0]
    # dist[i, j] is the Euclidean distance between scaled rows i and j.
    dist = squareform(pdist(X))
    for i in range(1, n):
        g.add_node(i)
        kernel = np.exp(-dist[i, :i] / r)
        degrees = np.array([g.degree(j) for j in range(i)], dtype=float)
        weights = kernel * degrees
        total = weights.sum()
        if total <= 0:
            # All candidates have zero weight (e.g. i == 1: node 0 has
            # degree 0). Normalising would give NaN and make
            # np.random.choice raise, so fall back to distance only.
            weights = kernel
            total = weights.sum()
            if total <= 0:
                # Kernel underflowed everywhere: choose uniformly.
                weights = np.ones(i)
                total = float(i)
        weights = weights / total
        for _ in range(m):
            neighbor = np.random.choice(i, p=weights)
            g.add_edge(i, neighbor)
            if random.random() < p_recip:
                g.add_edge(neighbor, i)

    return g

import json
import numpy as np
import networkx as nx
import random
import scipy as sp
import scipy.stats
import pickle
# Population sizes to process; for each size, 100 synthetic populations
# (netnum 0..99) are read from data/synthetic_spa/.
nvals = [500]

for n in nvals:
    for netnum in range(100):
        with open('data/synthetic_spa/agents_' + str(n) + '_' + str(netnum) + '.json', 'r') as f:
            a = json.load(f)
        # One feature row per agent. `float` replaces the `np.float` alias,
        # which was removed in NumPy 1.24.
        X = np.array([get_feature_vector(agent) for agent in a], dtype=float)
        r = 0.5
        g = spatial_pa(X, r, 5, 1)
        # Attach the raw categorical labels to each node
        # (`g.nodes`, since `g.node` was removed in NetworkX 2.4).
        for i, agent in enumerate(a):
            for feature in ['region', 'ethnicity', 'age', 'gender', 'status']:
                g.nodes[i][feature] = agent[feature]
        # The last feature column is the ordinal status code.
        status = X[:, -1]
        # Kendall's tau between the status at the tail and the head of each
        # edge: a quick measure of how assortative the generated graph is.
        state1 = [status[u] for u, v in g.edges()]
        state2 = [status[v] for u, v in g.edges()]
        print(n, sp.stats.kendalltau(state1, state2)[0])
        nx.write_edgelist(g, 'data/synthetic_spa/synthetic_spa_{:.2f}_{}.cites'.format(r, netnum), data=False)

        # Content file rows: node index, all features, then status again as
        # the label column. Size by the agents actually loaded, not `n`,
        # so a file with a different agent count does not break the reshape.
        num_nodes = X.shape[0]
        node_indices = np.arange(num_nodes).reshape((num_nodes, 1))
        statuses = status.reshape((num_nodes, 1))
        featurematrix = np.concatenate((node_indices, X, statuses), axis=1)
        np.savetxt('data/synthetic_spa/synthetic_spa_{:.2f}_{}.content'.format(r, netnum), featurematrix, fmt='%d')