from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import hashlib
import logging
import math
import os
import pickle
import sys
import warnings

from absl import flags
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from gnn import *
from graph_data import *
from loss import *
from utils import *

# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings(action="ignore")

# Attention params.
# Every flag in this group is forwarded to `self_attn_gnn` further down,
# under the keyword obtained by dropping the `attn_` prefix.
flags.DEFINE_integer(name='attn_kv_dim', default=20, help='')
flags.DEFINE_integer(name='attn_output_dim', default=20, help='')
flags.DEFINE_integer(name='attn_num_heads', default=8, help='')
flags.DEFINE_bool(name='attn_multi_proj', default=True, help='')
flags.DEFINE_integer(name='attn_multi_proj_dim', default=160, help='')
flags.DEFINE_bool(name='attn_concat', default=True, help='')
flags.DEFINE_bool(name='attn_residual', default=False, help='')
flags.DEFINE_bool(name='attn_layer_norm', default=False, help='')

# Dataset params.
# `dataset` selects an entry of DATASET_MAP (built later in this script).
flags.DEFINE_string(name='dataset', default='graph_rnn_grid', help='')
flags.DEFINE_float(name='split_train_percent', default=0.8, help='')

# Eval params.
# NOTE(review): the eval step in the training loop is currently disabled, so
# these two flags have no effect in the visible code.
flags.DEFINE_bool(name='run_eval', default=True, help='')
flags.DEFINE_integer(name='eval_every_n_steps', default=100, help='')

# Training params.
# `random_seed` seeds Python's `random`; `tf_random_seed` seeds TensorFlow.
flags.DEFINE_integer(name='random_seed', default=12345, help='')
flags.DEFINE_integer(name='tf_random_seed', default=601904901297, help='')
# Joined onto $MLPATH (or '.') to form the run's output directory.
flags.DEFINE_string(
    name='logdir', default='test_gnn', help='Where to write training files.')
flags.DEFINE_integer(name='train_batch_size', default=32, help='')
flags.DEFINE_integer(name='num_train_iters', default=200000, help='')
flags.DEFINE_integer(name='log_every_n_steps', default=50, help='')
flags.DEFINE_integer(name='summary_every_n_steps', default=25, help='')
flags.DEFINE_bool(name='save_trajectories', default=False, help='')
flags.DEFINE_integer(name='max_checkpoints_to_keep', default=5, help='')
flags.DEFINE_integer(name='save_every_n_iter', default=10000, help='')
# Directory (relative to the log prefix) searched for the latest checkpoint.
flags.DEFINE_string(name='pretrained_model', default='', help='')

# Optimizer params.
# The learning-rate flags feed tf.train.exponential_decay and
# get_learning_rate() below.
# NOTE(review): the momentum flags are not read anywhere in the visible
# script (the optimizer is Adam) - confirm whether they are still needed.
flags.DEFINE_float(
    name='learning_rate', default=5e-04, help='Learning rate for optimizer.')
flags.DEFINE_float(
    name='momentum', default=0.99, help='Used for momentum optimizer.')
flags.DEFINE_bool(name='use_learning_rate_decay', default=True, help='')
flags.DEFINE_integer(name='learning_rate_decay_steps', default=1000, help='')
flags.DEFINE_float(name='learning_rate_decay_rate', default=0.99, help='')
flags.DEFINE_bool(name='learning_rate_decay_staircase', default=True, help='')
flags.DEFINE_integer(name='learning_rate_rampup', default=1000, help='')
flags.DEFINE_integer(name='learning_rate_hold', default=2000, help='')
flags.DEFINE_bool(name='use_momentum_decay', default=True, help='')
flags.DEFINE_integer(name='momentum_decay_steps', default=1000, help='')
flags.DEFINE_float(name='momentum_decay_rate', default=0.96, help='')
flags.DEFINE_bool(name='momentum_decay_staircase', default=True, help='')
flags.DEFINE_float(
    name='l2_regularizer_weight',
    default=0.0000001,
    help='Used to regularizer weights and biases of MLP.')

# Loss + distance function.
# The *_dist_fn flags must be keys of DIST_FN_MAP.
flags.DEFINE_string(
    name='loss_type', default='binary', help='Can be binary or triplet.')
flags.DEFINE_string(name='binary_dist_fn', default='hacky_sigmoid_l2', help='')
# When set, the sigmoid temperature/shift become trainable variables.
flags.DEFINE_bool(name='tune_sigmoid', default=False, help='')

# Triplet loss.
flags.DEFINE_string(name='triplet_dist_fn', default='l2', help='')
flags.DEFINE_string(
    name='triplet_adj_dist_fn', default='hacky_sigmoid_l2', help='')
# Must be a key of TRIPLET_LOSS_FN_MAP.
flags.DEFINE_string(
    name='triplet_loss_fn',
    default='margin',
    help='Can be margin or relative.')
flags.DEFINE_integer(name='num_sampling_loops', default=8, help='')

# GRevNet params.
# Passed as the second positional argument of GRevNet(...).
flags.DEFINE_integer(name='grevnet_num_coupling_layers', default=10, help='')

# GNN params.
# The boolean switches below are forwarded to TimestepGNN(...), except
# `use_fc_adj_mat`, which decides which placeholders the encoder reads.
flags.DEFINE_bool(name='residual', default=False, help='')
flags.DEFINE_bool(name='weight_sharing', default=True, help='')
flags.DEFINE_bool(name='use_batch_norm', default=True, help='')
flags.DEFINE_bool(name='use_layer_norm', default=False, help='')
flags.DEFINE_bool(name='use_fc_adj_mat', default=False, help='')
flags.DEFINE_bool(name='bn_test_local_stats', default=True, help='')
flags.DEFINE_integer(
    name='num_layers', default=3, help='Num of layers of MLP used in GNN.')
flags.DEFINE_integer(
    name='latent_dim', default=2048, help='Latent dim of MLP used in GNN.')
flags.DEFINE_integer(
    name='num_processing_steps',
    default=15,
    help='Number of steps to take in the GNN.')
flags.DEFINE_float(
    name='bias_init_stddev',
    default=0.3,
    help='Used for initializing bias weights in GNN.')
flags.DEFINE_float(
    name='node_weighting_epsilon',
    default=2.0,
    help='How much to weight current node embedding over its neighbors.')

# Graph params.
flags.DEFINE_string(
    name='graph_type', default='grid', help='Can be grid or barabasi.')

# Node feature params.
# `node_embedding_dim` is used as `num_components` for every feature
# generator and as the width of the latent z produced by the encoder head.
flags.DEFINE_integer('node_embedding_dim', 14, 'Dimension of node embeddings.')
# The value is looked up in NODE_FEATURES_MAP, so the help text must list the
# exact keys of that map; any other value raises KeyError at the lookup.
flags.DEFINE_string('node_features', 'gaussian',
                    'Can be laplacian, gaussian, zeros, positional, or '
                    'gaussian_adj.')
# Only applies to the 'gaussian' feature generator.
flags.DEFINE_float('gaussian_scale', 0.3,
                   'Scale to use for random Gaussian features.')
flags.DEFINE_integer('laplacian_random_seed', 1234,
                     'Random seed used for Laplacian feature generation.')

# Grid graph params.
# Comma-separated integers; parsed into the GRAPH_DIM tuple below.
flags.DEFINE_string('graph_dim', '10,10', '')

# Barabasi-Albert graph params.
# Defaults are given as real integers, matching every other integer flag in
# this file instead of relying on string-to-int coercion of the default.
flags.DEFINE_integer('barabasi_n', 20, 'Num of nodes in graph.')
flags.DEFINE_integer(
    'barabasi_m', 4,
    'Num of edges to attach from new node to existing nodes.')

# NOTE(review): `skip_gnn_output_norm` is an integer flag that is not read in
# the visible script - confirm whether it is still needed.
flags.DEFINE_integer('skip_gnn_output_norm', 1, '')
# When set, the training loop also logs the true and predicted adjacency.
flags.DEFINE_bool('print_adj', False, '')

FLAGS = tf.app.flags.FLAGS

# Output root: $MLPATH when it is set and non-empty, otherwise the current
# working directory.
logdir_prefix = os.environ.get('MLPATH') or '.'
LOGDIR = os.path.join(logdir_prefix, FLAGS.logdir)
# NOTE(review): no `exist_ok`, so re-using an existing log directory raises
# here - presumably intentional to avoid clobbering a previous run; confirm.
os.makedirs(LOGDIR)

# Logging and print options.
# Arrays print with three fixed decimals and no scientific notation.
np.set_printoptions(formatter={'float': '{: 0.3f}'.format}, suppress=True)
# Log records go both to stdout and to LOGDIR/OUTPUT_LOG.
handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler(os.path.join(LOGDIR, 'OUTPUT_LOG')),
]
logging.basicConfig(handlers=handlers, level=logging.INFO)
logger = logging.getLogger("logger")

# Seed TensorFlow first, then Python's `random`, from their own flags.
tf.random.set_random_seed(FLAGS.tf_random_seed)
random.seed(FLAGS.random_seed)

# Temperature and shift bound into the 'sigmoid_l2' distance function below.
# Both are fixed constants unless --tune_sigmoid is set, in which case they
# become trainable variables initialised from those constants.
temp_const = tf.constant(10.0)
shift_const = tf.constant(1.0)
if FLAGS.tune_sigmoid:
    temp = tf.get_variable(
        'temp', dtype=tf.float32, initializer=temp_const, trainable=True)
    shift = tf.get_variable(
        'shift', dtype=tf.float32, initializer=shift_const, trainable=True)
else:
    temp = temp_const
    shift = shift_const

# Flag value -> triplet loss implementation.
TRIPLET_LOSS_FN_MAP = {
    'margin': margin_loss,
    'relative': relative_loss,
}

# Flag value -> distance function. Only 'sigmoid_l2' needs the (possibly
# trainable) temperature and shift bound in.
DIST_FN_MAP = {
    'exp_l2': exp_l2,
    'sigmoid_dot': sigmoid_dot,
    'hacky_sigmoid_l2': hacky_sigmoid_l2,
    'dot': dot,
    'l2': l2,
    'sigmoid_l2': partial(sigmoid_l2, temp=temp, shift=shift),
}

# Flag value -> node feature generator, each pre-bound to produce
# `node_embedding_dim` components.
NODE_FEATURES_MAP = {
    'laplacian': partial(
        add_laplacian_features, num_components=FLAGS.node_embedding_dim),
    'gaussian': partial(
        add_gaussian_noise_features,
        num_components=FLAGS.node_embedding_dim,
        scale=FLAGS.gaussian_scale),
    'zeros': partial(
        add_zero_features, num_components=FLAGS.node_embedding_dim),
    'positional': partial(
        add_positional_encoding_features,
        num_components=FLAGS.node_embedding_dim),
    'gaussian_adj': partial(
        add_adj_gaussian_features,
        num_components=FLAGS.node_embedding_dim,
        scale=1.0),
}
add_node_features_fn = NODE_FEATURES_MAP[FLAGS.node_features]

# '10,10' -> (10, 10).
GRAPH_DIM = tuple(map(int, FLAGS.graph_dim.split(',')))

# Flag value -> zero-argument dataset factory. Built in stages so that each
# family of loaders is listed once; insertion order matches the flat literal
# this replaces.

# Grid loaders that only need the node-feature function.
DATASET_MAP = {
    key: partial(loader, add_node_features_fn)
    for key, loader in (
        ('grid_all', get_grid_dataset_all),
        ('grid_split', get_grid_dataset_split),
        ('grid_train_even_test_odd', get_grid_dataset_train_even_test_odd),
        ('grid_train_odd_test_even', get_grid_dataset_train_odd_test_even),
        ('grid_test_larger', get_grid_dataset_all_test_larger),
        ('grid_test_smaller', get_grid_dataset_all_test_smaller),
        ('grid_test_square', get_grid_dataset_all_test_square),
    )
}

# The single-grid loader additionally takes the requested grid dimensions.
DATASET_MAP['grid_single'] = partial(
    get_grid_dataset_single, GRAPH_DIM, add_node_features_fn)

# Community / ego loaders, again bound only to the node-feature function.
DATASET_MAP.update({
    key: partial(loader, add_node_features_fn)
    for key, loader in (
        ('community_large_split', get_large_community_dataset_split),
        ('community_small_split', get_small_community_dataset_split),
        ('ego_large_split', get_large_ego_dataset_split),
        ('ego_small_split', get_small_ego_dataset_split),
    )
})

# Datasets read from .dat files through load_graph_rnn_dataset.
DATASET_MAP.update({
    key: partial(load_graph_rnn_dataset, path, add_node_features_fn)
    for key, path in (
        ('graph_rnn_grid',
         'graphs/GraphRNN_RNN_grid_4_128_train_0.dat'),
        ('graph_rnn_protein',
         'graphs/GraphRNN_RNN_protein_4_128_train_0.dat'),
        ('graph_rnn_ego',
         'graphs/GraphRNN_RNN_citeseer_4_128_train_0.dat'),
        ('graph_rnn_community',
         'graphs/GraphRNN_RNN_caveman_4_128_train_0.dat'),
        ('graph_rnn_ego_small',
         'graphs/GraphRNN_RNN_citeseer_small_4_64_train_0.dat'),
        ('graph_rnn_community_small',
         'graphs/GraphRNN_RNN_caveman_small_4_64_train_0.dat'),
    )
})

# Resolve the flag-selected distance functions, in flag order. An unknown
# flag value raises KeyError here.
# NOTE(review): none of these three names is read again in the visible
# script; reconstruction_prob below is called with hacky_sigmoid_l2 directly.
binary_dist_fn, triplet_dist_fn, triplet_adj_dist_fn = (
    DIST_FN_MAP[dist_name]
    for dist_name in (
        FLAGS.binary_dist_fn,
        FLAGS.triplet_dist_fn,
        FLAGS.triplet_adj_dist_fn,
    )
)
triplet_loss_fn = TRIPLET_LOSS_FN_MAP[FLAGS.triplet_loss_fn]

# Look up the dataset factory and build the dataset immediately.
dataset = DATASET_MAP[FLAGS.dataset]()

# Define GNN and output.
# Placeholders for the observed graphs, templated on the test-set graphs.
# These are what the training loop feeds on every step.
true_graph_phs = gn.utils_tf.placeholders_from_networkxs(
    dataset.test_set, force_dynamic_num_graphs=True, name="true_graph_phs")

# Optional second set of placeholders that the encoder reads instead of
# `true_graph_phs` when --use_fc_adj_mat is set.
# Uses the same plural helper as above: `dataset.test_set` is passed in both
# places, and the singular spelling used previously does not match it.
# NOTE(review): the training loop only builds a feed for `true_graph_phs`;
# confirm these placeholders get fed before enabling the flag.
fc_graph_phs = None
if FLAGS.use_fc_adj_mat:
    fc_graph_phs = gn.utils_tf.placeholders_from_networkxs(
        dataset.test_set, name="fc_graph_phs")
# MLP factory for the encoder's attention GNN; the second positional argument
# is the full node embedding width.
make_mlp_fn = partial(
    make_mlp_model,
    FLAGS.latent_dim,
    FLAGS.node_embedding_dim,
    FLAGS.num_layers,
    l2_regularizer_weight=FLAGS.l2_regularizer_weight,
    bias_init_stddev=FLAGS.bias_init_stddev)
# NOTE: this rebinds the module-level name `self_attn_gnn` (presumably
# provided by one of the star imports) to a pre-configured partial. Everything
# below that refers to `self_attn_gnn` gets this partial.
self_attn_gnn = partial(
    self_attn_gnn,
    kv_dim=FLAGS.attn_kv_dim,
    output_dim=FLAGS.attn_output_dim,
    make_mlp_fn=make_mlp_fn,
    batch_size=FLAGS.train_batch_size,
    num_heads=FLAGS.attn_num_heads,
    multi_proj_dim=FLAGS.attn_multi_proj_dim,
    multi_proj=FLAGS.attn_multi_proj,
    concat=FLAGS.attn_concat,
    residual=FLAGS.attn_residual,
    layer_norm=FLAGS.attn_layer_norm)

# MLP factory for the GRevNet layers, sized to half the embedding width.
# Floor division is required: this file enables `from __future__ import
# division`, so `/` would hand a float (e.g. 7.0) to the model builder.
grevnet_make_mlp_fn = partial(
    make_mlp_model,
    FLAGS.latent_dim,
    FLAGS.node_embedding_dim // 2,
    FLAGS.num_layers,
    l2_regularizer_weight=FLAGS.l2_regularizer_weight,
    bias_init_stddev=FLAGS.bias_init_stddev)
# Built on top of the partial above; the keywords given here override the
# ones already bound (in particular `make_mlp_fn`).
grevnet_self_attn_gnn = partial(
    self_attn_gnn,
    kv_dim=FLAGS.attn_kv_dim,
    output_dim=FLAGS.attn_output_dim,
    make_mlp_fn=grevnet_make_mlp_fn,
    batch_size=FLAGS.train_batch_size,
    num_heads=FLAGS.attn_num_heads,
    multi_proj_dim=FLAGS.attn_multi_proj_dim,
    multi_proj=FLAGS.attn_multi_proj,
    concat=FLAGS.attn_concat,
    residual=FLAGS.attn_residual,
    layer_norm=FLAGS.attn_layer_norm)

# Fed as True on every training step.
is_training = tf.placeholder(tf.bool, name="is_training")

# Encoder: the configured attention GNN wrapped in TimestepGNN.
gnn = TimestepGNN(
    self_attn_gnn,
    FLAGS.num_processing_steps,
    residual=FLAGS.residual,
    weight_sharing=FLAGS.weight_sharing,
    use_batch_norm=FLAGS.use_batch_norm,
    use_layer_norm=FLAGS.use_layer_norm,
    test_local_stats=FLAGS.bn_test_local_stats)

# Read the fully-connected placeholders when they were created, otherwise
# the observed-graph placeholders.
if fc_graph_phs:
    gnn_output = gnn(fc_graph_phs, is_training=is_training)
else:
    gnn_output = gnn(true_graph_phs, is_training=is_training)

# Presumably blocks gradients from flowing back into the encoder, so only the
# layers built after this point are trained - confirm against graph_nets.
gnn_output = gn.utils_tf.stop_gradient(gnn_output)
# Per-row L2 norm of the node outputs; logged as a diagnostic.
gnn_output_norm = tf.norm(gnn_output.nodes, axis=1)

# Get latent zs.
# Two dense heads on the encoder's node outputs produce a location and a
# scale per node, each `node_embedding_dim` wide; a sample is then drawn as
# loc + scale * noise. Statement order matters here: the two fully_connected
# calls get auto-numbered scopes, and the random op is created after them.
locs = tf.contrib.layers.fully_connected(
    gnn_output.nodes, num_outputs=FLAGS.node_embedding_dim, activation_fn=None)
# NOTE(review): ReLU lets a scale be exactly 0, which makes that coordinate
# of the sample deterministic - confirm this is intended.
scales = tf.contrib.layers.fully_connected(
    gnn_output.nodes,
    num_outputs=FLAGS.node_embedding_dim,
    activation_fn=tf.nn.relu)
# One noise row per node across the whole batch (sum of n_node).
epsilon = tf.random.normal(
    [tf.reduce_sum(gnn_output.n_node), FLAGS.node_embedding_dim])
sampled_zs = locs + scales * epsilon

# log prob(A|z).
# Swap the sampled latents in as node features and score them against the
# observed graphs.
new_gnn_output = gnn_output.replace(nodes=sampled_zs)
# NOTE(review): the distance function is hard-coded to hacky_sigmoid_l2
# rather than the flag-selected `binary_dist_fn` - confirm this is intended.
log_prob_reconstruction, pred_adj, true_adj = reconstruction_prob(
    new_gnn_output, true_graph_phs, hacky_sigmoid_l2)
# NOTE: this rebinds `num_incorrect` from the helper function to its result,
# so the helper cannot be called again after this line.
num_incorrect = num_incorrect(true_adj, pred_adj)
# Collapse to a single scalar for the objective.
log_prob_reconstruction = tf.reduce_sum(log_prob_reconstruction)

# log prob(z).
# Run the sampled latents through GRevNet with inverse=True, score the result
# under a zero-mean, unit-scale diagonal Gaussian, and add the log-determinant
# term returned by the same call.
grevnet = GRevNet(
    grevnet_self_attn_gnn,
    FLAGS.grevnet_num_coupling_layers,
    FLAGS.node_embedding_dim,
    use_batch_norm=True)
grevnet_reverse_output, log_det_jacobian = grevnet(
    new_gnn_output, inverse=True)
# Per-row L2 norm of the inverted latents; logged as a diagnostic.
grevnet_output_norm = tf.norm(grevnet_reverse_output.nodes, axis=1)
mvn = tfd.MultivariateNormalDiag(
    tf.zeros(FLAGS.node_embedding_dim), tf.ones(FLAGS.node_embedding_dim))
# Sum of per-node log-probabilities plus the log-determinant term.
log_prob_zs = tf.reduce_sum(mvn.log_prob(
    grevnet_reverse_output.nodes)) + log_det_jacobian

# entropy H(z|A).
# NOTE: `epsilon` is rebound here from the noise tensor used for sampling to
# a small Python float that keeps the log argument strictly positive.
epsilon = 0.0001
# NOTE(review): the log argument uses `scales`, not `scales ** 2`. If `scales`
# is meant to be a standard deviation, the Gaussian entropy would normally
# use the variance - confirm which is intended.
entropy = 0.5 * tf.math.log(2 * math.pi * math.e * scales + epsilon)
total_entropy = tf.reduce_sum(entropy)
# Objective: reconstruction term + flow prior term + entropy term; the loss
# minimised by the optimizer is its negation.
total_log_prob = log_prob_reconstruction + log_prob_zs + total_entropy
total_loss = -1 * total_log_prob

# Optimizer.
global_step = tf.Variable(0, trainable=False, name='global_step')

# The decay schedule is always built; it is only used when the flag is on.
decaying_learning_rate = tf.train.exponential_decay(
    FLAGS.learning_rate,
    global_step,
    FLAGS.learning_rate_decay_steps,
    FLAGS.learning_rate_decay_rate,
    staircase=FLAGS.learning_rate_decay_staircase)
if FLAGS.use_learning_rate_decay:
    learning_rate = decaying_learning_rate
else:
    learning_rate = FLAGS.learning_rate

# NOTE(review): the training loop feeds this placeholder with the
# ramp-up/hold schedule, but no op in the visible graph consumes it - Adam is
# built from `learning_rate` above. Confirm which schedule is intended.
learning_rate_placeholder = tf.placeholder(
    tf.float32, [], name='learning_rate')

optimizer = tf.train.AdamOptimizer(learning_rate)
# Make the train step wait for everything registered under UPDATE_OPS.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    step_op = optimizer.minimize(total_loss, global_step=global_step)

# One histogram summary per trainable variable, tagged with its op name.
for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    tf.summary.histogram(variable.op.name, variable)

# Scalar summaries for the objective and its components.
for summary_tag, summary_tensor in (
        ('total_loss', total_loss),
        ('num_incorrect', num_incorrect),
        ('log_prob_reconstruction', log_prob_reconstruction),
        ('log_prob_zs', log_prob_zs),
        ('total_entropy', total_entropy),
        ('log_det_jacobian', log_det_jacobian),
):
    tf.summary.scalar(summary_tag, summary_tensor)

merged = tf.summary.merge_all()
sess = reset_sess()

# Separate event directories for training and evaluation summaries.
train_summary_dir = os.path.join(LOGDIR, 'train')
eval_summary_dir = os.path.join(LOGDIR, 'test')
train_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
eval_writer = tf.summary.FileWriter(eval_summary_dir, sess.graph)

# Record every flag value next to the logs, one "name: value" line per flag.
flags_map = FLAGS.flag_values_dict()
desc_path = os.path.join(LOGDIR, 'desc.txt')
with open(desc_path, 'w') as f:
    f.writelines(
        "{}: {}\n".format(flag_name, str(flag_value))
        for flag_name, flag_value in flags_map.items())

# Saver restricted to the global variables under the "TimestepGNN" scope.
encoder_variables = tf.global_variables(scope="TimestepGNN")
saver = tf.train.Saver(encoder_variables)

# Latest checkpoint under <log prefix>/<pretrained_model>.
pretrained_dir = os.path.join(logdir_prefix, FLAGS.pretrained_model)
checkpoint = tf.train.latest_checkpoint(pretrained_dir)
logger.info("Restoring encoder-decoder model from {}".format(checkpoint))

# NOTE(review): the restore call is commented out, so nothing is actually
# loaded even though both log lines report a restore.
#saver.restore(sess, checkpoint)
logger.info("Done restoring encoder-decoder model!")

# Result of the most recent training step; stays empty if no step runs.
train_values = {}
# Everything fetched by sess.run on each training step, keyed by the name
# used when logging.
values_map = {
    "merge": merged,
    "step_op": step_op,
    "true_graph_phs": true_graph_phs,
    "gnn_output": gnn_output.nodes,
    "total_loss": total_loss,
    "log_prob_reconstruction": log_prob_reconstruction,
    "log_prob_zs": log_prob_zs,
    "total_entropy": total_entropy,
    "true_adj": true_adj,
    "pred_adj": pred_adj,
    "num_incorrect": num_incorrect,
    "gnn_output_norm": gnn_output_norm,
    "grevnet_output_norm": grevnet_output_norm,
}

# Register every fetch except the input placeholders in a graph collection
# named after its key. The key is compared by value (`!=`): an identity test
# against a string literal only works when the interpreter happens to intern
# both strings, and newer Python versions warn about it.
for k, v in values_map.items():
    if k != "true_graph_phs":
        tf.add_to_collection(k, v)
# Disabled eval fetches, kept as an inert string literal.
'''
eval_values_map = {
    "merge": merged,
    "eval_mean_loss": mean_loss,
    "eval_sum_loss": sum_loss,
    "eval_num_incorrect": num_incorrect,
    "eval_true_graph_phs": true_graph_phs,
}
'''

# Main training loop. Wrapped in try/finally so the summary writers are
# closed (and their buffered events written) even when training is
# interrupted or a step raises; previously they were never closed.
try:
    for iteration in range(FLAGS.num_train_iters):
        # If needed run eval step. (Currently disabled; kept as an inert
        # string literal.)
        '''
        if FLAGS.run_eval and iteration % FLAGS.eval_every_n_steps == 0:
            feed_dict = dataset.get_test_set(true_graph_phs)
            feed_dict[is_training] = False
            values = sess.run(eval_values_map, feed_dict=feed_dict)
            eval_writer.add_summary(values["merge"], iteration)
            logger.info("*" * 100)
            logger.info("iteration num: {}".format(iteration))
            logger.info("eval sum loss: {}".format(values["eval_sum_loss"]))
            logger.info("eval mean loss: {}".format(values["eval_mean_loss"]))
            logger.info("eval num incorrect: {}".format(
                values["eval_num_incorrect"]))
            logger.info("eval num nodes: {}".format(
                values["eval_true_graph_phs"].n_node))
        '''

        # Run train step.
        feed_dict = dataset.get_next_train_batch(FLAGS.train_batch_size,
                                                 true_graph_phs)
        feed_dict[is_training] = True
        # NOTE(review): this placeholder is fed on every step, but the
        # optimizer in this script is not built from it - confirm intent.
        feed_dict[learning_rate_placeholder] = get_learning_rate(
            iteration, FLAGS.learning_rate, FLAGS.learning_rate_rampup,
            FLAGS.learning_rate_hold)

        train_values = sess.run(values_map, feed_dict=feed_dict)
        # In-place sort of the fetched norms; the sorted array is what the
        # final summary call receives.
        train_values["gnn_output_norm"].sort()
        if train_writer and (iteration % FLAGS.summary_every_n_steps == 0):
            train_writer.add_summary(train_values['merge'], iteration)
        if iteration % FLAGS.log_every_n_steps == 0:
            logger.info("*" * 100)
            logger.info("iteration num: {}".format(iteration))
            logger.info("num nodes: {}".format(
                train_values["true_graph_phs"].n_node))
            logger.info("total loss: {}".format(train_values["total_loss"]))
            logger.info("log prob reconstruction: {}".format(
                train_values["log_prob_reconstruction"]))
            logger.info("log prob zs: {}".format(train_values["log_prob_zs"]))
            logger.info("total entropy: {}".format(
                train_values["total_entropy"]))
            logger.info("num incorrect: {}".format(
                train_values["num_incorrect"]))
            logger.info("gnn output norm:{}".format(
                np.mean(train_values["gnn_output_norm"])))
            logger.info("grevnet output norm:{}".format(
                np.mean(train_values["grevnet_output_norm"])))

            if FLAGS.print_adj:
                logger.info("true_adj:\n{}".format(train_values["true_adj"]))
                logger.info("pred_adj:\n{}".format(train_values["pred_adj"]))

        # Save model. (Currently disabled; kept as an inert string literal.)
        '''
        if iteration % FLAGS.save_every_n_iter == 0:
            saver.save(
                sess, os.path.join(LOGDIR, 'checkpoints'),
                global_step=global_step)
        '''
finally:
    # Flush and release both event files regardless of how the loop ended.
    for summary_writer in (train_writer, eval_writer):
        if summary_writer:
            summary_writer.close()

# Summarise the last training step. `train_values` is still the empty dict
# if the loop body never ran.
print_adjacency_summary(logger, train_values)
