import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Args:
        in_features: number of input features per node.
        out_features: number of output features per node.
        dropout: dropout probability applied to the attention coefficients.
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: if True an ELU is applied to the aggregated features before
            they are returned; if False they are returned unchanged.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Projection matrix W (in_features x out_features).
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Attention vector a; the first out_features rows weight the source
        # node, the remaining out_features rows weight the target node.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Aggregate projected node features with masked attention.

        Args:
            input: node feature matrix of shape (N, in_features).
            adj: tensor broadcastable to (N, N); entries > 0 mark the pairs
                (i, j) that are allowed to attend to each other.

        Returns:
            Tensor of shape (N, out_features).
        """
        h = torch.mm(input, self.W)  # N x out_features

        # a^T [h_i || h_j] == a_src^T h_i + a_dst^T h_j, so the pairwise
        # scores are obtained from two matrix-vector products and a broadcast
        # sum instead of materialising an N x N x 2*out_features tensor.
        score_src = torch.matmul(h, self.a[:self.out_features, :])  # N x 1
        score_dst = torch.matmul(h, self.a[self.out_features:, :])  # N x 1
        e = self.leakyrelu(score_src + score_dst.t())  # N x N

        # Large negative fill so that non-edges get ~0 weight after softmax.
        masked_scores = torch.where(adj > 0, e, torch.full_like(e, -9e15))
        attention = F.softmax(masked_scores, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class SpecialSpmmFunction(torch.autograd.Function):
    """Sparse x dense matmul whose backward pass only touches the sparse region.

    The sparse operand is described by ``indices`` (2 x E), ``values`` (E) and
    ``shape``; gradients are produced for ``values`` and for the dense operand
    ``b`` only.
    """
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        """Return ``sparse(indices, values, shape) @ b``."""
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (indices, values, shape, b); only values and b get one."""
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # d out[i, :] / d a[i, j] = b[j, :], hence the gradient of the
            # value stored at (i, j) is <grad_output[i, :], b[j, :]>.
            # Computing it per edge avoids a dense shape[0] x shape[1]
            # intermediate and is also correct for non-square shapes.
            edge = a._indices()
            grad_values = (grad_output[edge[0, :]] * b[edge[1, :]]).sum(dim=1)
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    """Module wrapper that exposes SpecialSpmmFunction as a layer."""

    def forward(self, indices, values, shape, b):
        """Multiply the sparse matrix given by (indices, values, shape) with b."""
        product = SpecialSpmmFunction.apply(indices, values, shape, b)
        return product

    
class SpGraphAttentionLayer(nn.Module):
    """
    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903

    Args:
        in_features: number of input features per node.
        out_features: number of output features per node.
        dropout: dropout probability applied to the per-edge attention scores.
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: if True an ELU is applied to the aggregated features before
            they are returned; if False they are returned unchanged.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Projection matrix W (in_features x out_features).
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)

        # Attention row vector a (1 x 2*out_features).
        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, adj):
        """Aggregate projected node features over the edges stored in ``adj``.

        Args:
            input: node feature matrix of shape (N, in_features).
            adj: sparse COO tensor; its coalesced indices (2 x E) are used as
                the edge list.

        Returns:
            Tensor of shape (N, out_features).

        Raises:
            AssertionError: if a NaN appears in any intermediate result.
        """
        N = input.size()[0]
        edge = adj.coalesce().indices()

        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()

        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_h: 2*D x E

        # squeeze(0) rather than squeeze(): with a single edge a bare squeeze
        # would also drop the edge dimension and yield a 0-dim tensor.
        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze(0)))
        assert not torch.isnan(edge_e).any()
        # edge_e: E

        # The all-ones column must share h's device/dtype, otherwise the
        # sparse matmul fails when the layer runs on GPU or in float64.
        ones_col = torch.ones(size=(N, 1), dtype=h.dtype, device=h.device)
        e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), ones_col)
        # e_rowsum: N x 1

        edge_e = self.dropout(edge_e)
        # edge_e: E

        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x out

        # Normalise by the (pre-dropout) row sums of the attention scores.
        h_prime = h_prime.div(e_rowsum)
        # h_prime: N x out
        assert not torch.isnan(h_prime).any()

        if self.concat:
            # if this layer is not last layer,
            return F.elu(h_prime)
        else:
            # if this layer is last layer,
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    """Dense GAT model: ``nheads`` attention heads followed by one output head."""

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        # Heads live in a plain list and are registered one by one under the
        # names 'attention_<i>' so their parameters are tracked by the module.
        self.attentions = []
        for head_idx in range(nheads):
            head = GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            self.attentions.append(head)
            self.add_module('attention_{}'.format(head_idx), head)

        # The output head consumes the concatenation of all head outputs.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        """Run dropout -> heads (concatenated on dim 1) -> dropout -> ELU(output head)."""
        dropped_in = F.dropout(x, self.dropout, training=self.training)
        head_outputs = [head(dropped_in, adj) for head in self.attentions]
        hidden = torch.cat(head_outputs, dim=1)
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return F.elu(self.out_att(hidden, adj))


class SpGAT(nn.Module):
    """Sparse GAT model: ``nheads`` sparse attention heads plus one output head."""

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Sparse version of GAT."""
        super(SpGAT, self).__init__()
        self.dropout = dropout

        self.attentions = [SpGraphAttentionLayer(nfeat,
                                                 nhid,
                                                 dropout=dropout,
                                                 alpha=alpha,
                                                 concat=True) for _ in range(nheads)]
        # Register every head under 'attention_<i>' so its parameters are
        # tracked even though the heads are kept in a plain Python list.
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # The output head consumes the concatenation of all head outputs.
        self.out_att = SpGraphAttentionLayer(nhid * nheads,
                                             nclass,
                                             dropout=dropout,
                                             alpha=alpha,
                                             concat=False)

    def forward(self, x, adj):
        """Run dropout -> heads (concatenated on dim 1) -> dropout -> ELU(output head).

        The computation mirrors ``GAT.forward``; the per-step debug printing
        of full tensors has been removed.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return x


import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

import math
from torch_geometric.nn import GATConv



def uniform(size, tensor):
    """Fill ``tensor`` in place with samples from U(-1/sqrt(size), 1/sqrt(size)).

    Does nothing when ``tensor`` is None.
    """
    limit = 1.0 / math.sqrt(size)
    if tensor is None:
        return
    tensor.data.uniform_(-limit, limit)


def kaiming_uniform(tensor, fan, a):
    """Fill ``tensor`` in place with samples from U(-bound, bound).

    ``bound`` is sqrt(6 / ((1 + a**2) * fan)). Does nothing when ``tensor``
    is None.
    """
    slope_term = 1 + a**2
    limit = math.sqrt(6 / (slope_term * fan))
    if tensor is None:
        return
    tensor.data.uniform_(-limit, limit)


def glorot(tensor):
    """Fill ``tensor`` in place with Glorot/Xavier-uniform samples.

    Samples are drawn from U(-s, s) with
    s = sqrt(6 / (tensor.size(-2) + tensor.size(-1))).
    Does nothing when ``tensor`` is None, like the sibling init helpers.
    """
    if tensor is None:
        return
    # The bound is computed only after the None check; reading the sizes
    # first would raise AttributeError for a missing tensor.
    stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
    tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    """Set every element of ``tensor`` to 0 in place; no-op when it is None."""
    if tensor is None:
        return
    tensor.data.fill_(0)


def ones(tensor):
    """Set every element of ``tensor`` to 1 in place; no-op when it is None."""
    if tensor is None:
        return
    tensor.data.fill_(1)


def reset(nn):
    """Call ``reset_parameters()`` on the direct children of ``nn``.

    If ``nn`` has no ``children`` attribute or no children, ``nn`` itself is
    reset instead. Objects without a ``reset_parameters`` attribute are
    skipped, and None is ignored.
    """
    if nn is None:
        return

    children = list(nn.children()) if hasattr(nn, 'children') else []
    # Fall back to the object itself when there is nothing below it.
    targets = children if len(children) > 0 else [nn]
    for target in targets:
        if hasattr(target, 'reset_parameters'):
            target.reset_parameters()

#class GATConv(MessagePassing):
#    r"""The graph attentional operator from the `"Graph Attention Networks"
#    <https://arxiv.org/abs/1710.10903>`_ paper
#
#    .. math::
#        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
#        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
#
#    where the attention coefficients :math:`\alpha_{i,j}` are computed as
#
#    .. math::
#        \alpha_{i,j} =
#        \frac{
#        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
#        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
#        \right)\right)}
#        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
#        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
#        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
#        \right)\right)}.
#
#    Args:
#        in_channels (int): Size of each input sample.
#        out_channels (int): Size of each output sample.
#        heads (int, optional): Number of multi-head-attentions.
#            (default: :obj:`1`)
#        concat (bool, optional): If set to :obj:`False`, the multi-head
#        attentions are averaged instead of concatenated. (default: :obj:`True`)
#        negative_slope (float, optional): LeakyReLU angle of the negative
#            slope. (default: :obj:`0.2`)
#        dropout (float, optional): Dropout probability of the normalized
#            attention coefficients which exposes each node to a stochastically
#            sampled neighborhood during training. (default: :obj:`0`)
#        bias (bool, optional): If set to :obj:`False`, the layer will not learn
#            an additive bias. (default: :obj:`True`)
#    """
#
#    def __init__(self,
#                 in_channels,
#                 out_channels,
#                 heads=1,
#                 concat=True,
#                 negative_slope=0.2,
#                 dropout=0,
#                 bias=True):
#        super(GATConv, self).__init__('add')
#
#        self.in_channels = in_channels
#        self.out_channels = out_channels
#        self.heads = heads
#        self.concat = concat
#        self.negative_slope = negative_slope
#        self.dropout = dropout
#
#        self.weight = Parameter(
#            torch.Tensor(in_channels, heads * out_channels))
#        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))
#
#        if bias and concat:
#            self.bias = Parameter(torch.Tensor(heads * out_channels))
#        elif bias and not concat:
#            self.bias = Parameter(torch.Tensor(out_channels))
#        else:
#            self.register_parameter('bias', None)
#
#        self.reset_parameters()
#
#    def reset_parameters(self):
#        glorot(self.weight)
#        glorot(self.att)
#        zeros(self.bias)
#
#
#    def forward(self, x, edge_index):
#        """"""
#        edge_index, _ = remove_self_loops(edge_index)
#        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))
#
#        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
#        return self.propagate(edge_index, x=x, num_nodes=x.size(0))
#
#
#    def message(self, edge_index_i, x_i, x_j, num_nodes):
#        # Compute attention coefficients.
#        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
#        alpha = F.leaky_relu(alpha, self.negative_slope)
#        alpha = softmax(alpha, edge_index_i, num_nodes)
#
#        # Sample attention coefficients stochastically.
#        if self.training and self.dropout > 0:
#            alpha = F.dropout(alpha, p=self.dropout, training=True)
#
#        return x_j * alpha.view(-1, self.heads, 1)
#
#    def update(self, aggr_out):
#        if self.concat is True:
#            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
#        else:
#            aggr_out = aggr_out.mean(dim=1)
#
#        if self.bias is not None:
#            aggr_out = aggr_out + self.bias
#        return aggr_out
#
#    def __repr__(self):
#        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
#                                             self.in_channels,
#                                             self.out_channels, self.heads)
#


class GATGeom(nn.Module):
    """Two-layer GAT built from torch_geometric's ``GATConv``."""

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Build the two ``GATConv`` layers.

        Args:
            nfeat: number of input features per node.
            nhid: hidden size per attention head of the first layer.
            nclass: output size of the second (single-head) layer.
            dropout: dropout probability, used on the features in ``forward``
                and passed to both ``GATConv`` layers.
            alpha: value passed as ``negative_slope`` to both layers.
            nheads: number of attention heads of the first layer.
        """
        super(GATGeom, self).__init__()
        self.dropout = dropout

        self.conv1 = GATConv(nfeat, nhid, dropout=dropout, heads=nheads, negative_slope=alpha)
        self.conv2 = GATConv(nhid*nheads, nclass, dropout=dropout, heads=1, negative_slope=alpha)

    def forward(self, x, adj):
        """Run dropout -> ELU(conv1) -> dropout -> conv2 on the edges of ``adj``.

        Args:
            x: node feature matrix.
            adj: sparse COO tensor whose indices are used as ``edge_index``.
        """
        # indices() raises on an uncoalesced sparse tensor, so coalesce first
        # (same as SpGraphAttentionLayer.forward does).
        edge_index = adj.coalesce().indices()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        # NOTE(review): constant offset kept from the original code; its
        # purpose is not evident from this file -- confirm before removing.
        x = x + 0.000001
        return x
