import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
import torch
import numpy as np


class Dense_GCN(nn.Module):
    """Densely connected GCN.

    Each hidden layer receives the original input features concatenated with
    the outputs of every previous hidden layer.

    Args:
        nfeat: number of input features per node.
        nlayers: number of hidden GraphConvolution layers.
        nhid: output width of every hidden layer.
        nclass: output width of the final GraphConvolution layer.
        dropout: dropout probability applied to each hidden output after the
            activation (the raw input is not dropped out).
        activation: callable applied to each hidden layer's output.
        eye: forwarded as ``eye=`` to the output layer only. Its meaning is
            defined by ``pygcn.layers.GraphConvolution``. NOTE(review):
            presumably it toggles adding the identity/self-loops; confirm.
        with_features: if truthy, the output layer also sees the raw input
            features; otherwise it sees only the hidden-layer outputs.

    ``forward(x, adj)`` returns ``(log_probs, features)``. ``log_probs`` is
    ``log_softmax`` over dim 1 of the output layer, and ``features`` is the
    exact tensor that was fed to the output layer.
    """

    def __init__(self, nfeat, nlayers, nhid, nclass, dropout, activation, eye=True, with_features=False):
        super(Dense_GCN, self).__init__()
        self.dropout = dropout
        self.activation = activation
        self.eye = eye
        self.with_features = with_features
        # Layer k consumes the input features plus the k earlier hidden outputs.
        # Hidden layers are created before the output layer, as before, so
        # parameter registration and initialisation order are unchanged.
        self.hidden = nn.ModuleList(
            GraphConvolution(k * nhid + nfeat, nhid) for k in range(nlayers)
        )
        out_in = nlayers * nhid + (nfeat if with_features else 0)
        self.out = GraphConvolution(out_in, nclass)

    def forward(self, x, adj):
        layer_input = x
        # Zero-width accumulator on x's own device and dtype. It is the identity
        # for concatenation along dim 1, so this works on CPU as well as GPU
        # (previously hard-coded to `.cuda()`).
        extracted_features = x.new_zeros((x.size(0), 0))
        for layer in self.hidden:
            h = self.activation(layer(layer_input, adj))
            h = F.dropout(h, self.dropout, training=self.training)
            # The newest output is prepended. Later layers' weights are laid out
            # for this [h_k, ..., h_1, x] ordering, so it must not change.
            layer_input = torch.cat([h, layer_input], 1)
            extracted_features = torch.cat([h, extracted_features], 1)
        features = layer_input if self.with_features else extracted_features
        output = self.out(features, adj, eye=self.eye)
        return F.log_softmax(output, dim=1), features

    def reset_parameters(self):
        """Re-initialise every hidden layer and the output layer."""
        for layer in self.hidden:
            layer.reset_parameters()
        self.out.reset_parameters()

class Simplified_GCN(nn.Module):
    """Plain stacked GCN whose output layer reads all hidden outputs.

    Hidden layers are chained: each one consumes only the previous layer's
    output. The output layer, however, receives the concatenation of every
    hidden layer's output, optionally preceded by the raw input features.

    Args:
        nfeat: number of input features per node.
        nlayers: number of hidden GraphConvolution layers.
        nhid: output width of every hidden layer.
        nclass: output width of the final GraphConvolution layer.
        dropout: dropout probability applied to each hidden output after the
            activation.
        activation: callable applied to each hidden layer's output.
        eye: forwarded as ``eye=`` to the output layer only. NOTE(review):
            semantics are defined in ``pygcn.layers.GraphConvolution``; confirm.
        with_features: if truthy, the raw input features are prepended to the
            hidden outputs fed to the output layer.

    ``forward(x, adj)`` returns ``(log_probs, features)``. ``features`` is the
    tensor fed to the output layer.
    """

    def __init__(self, nfeat, nlayers, nhid, nclass, dropout, activation, eye, with_features=False):
        super(Simplified_GCN, self).__init__()
        self.dropout = dropout
        self.activation = activation
        self.eye = eye
        self.with_features = with_features
        # Layer widths [nfeat, nhid, ..., nhid] as plain Python ints, so no
        # numpy scalars are handed to GraphConvolution.
        width = [int(nfeat)] + [int(nhid)] * nlayers
        self.hidden = nn.ModuleList(
            GraphConvolution(width[k], width[k + 1]) for k in range(nlayers)
        )
        total = sum(width)
        out_in = total if with_features else total - width[0]
        self.out = GraphConvolution(out_in, nclass)

    def forward(self, x, adj):
        input_feature = x
        # Zero-width accumulator matching x's device and dtype (was `.cuda()`).
        layer_input = x.new_zeros((x.size(0), 0))
        for layer in self.hidden:
            x = self.activation(layer(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
            # The newest hidden output is prepended: [h_k, ..., h_1].
            layer_input = torch.cat([x, layer_input], 1)

        if self.with_features:
            # Here the raw input goes FIRST. The output layer's weights depend
            # on this ordering.
            features = torch.cat([input_feature, layer_input], 1)
        else:
            features = layer_input
        output = self.out(features, adj, eye=self.eye)
        return F.log_softmax(output, dim=1), features

    def reset_parameters(self):
        """Re-initialise every hidden layer and the output layer."""
        for layer in self.hidden:
            layer.reset_parameters()
        self.out.reset_parameters()

class GCN(nn.Module):
    """Standard stacked GCN: the output layer reads only the last hidden output.

    Args:
        nfeat: number of input features per node.
        nlayers: number of hidden GraphConvolution layers (may be 0, in which
            case the output layer is applied directly to the input).
        nhid: output width of every hidden layer.
        nclass: output width of the final GraphConvolution layer.
        dropout: dropout probability applied to each hidden output after the
            activation.
        activation: callable applied to each hidden layer's output.
        eye: forwarded as ``eye=`` to the output layer only. NOTE(review):
            semantics are defined in ``pygcn.layers.GraphConvolution``; confirm.
        with_features: stored for interface parity with the other models.
            This model's ``forward`` does not read it.

    ``forward(x, adj)`` returns ``(log_probs, features)``. ``features`` is the
    concatenation ``[h_k, ..., h_1, x]`` of all hidden outputs and the input.
    It is returned for inspection and is not fed to the output layer.
    """

    def __init__(self, nfeat, nlayers, nhid, nclass, dropout, activation, eye, with_features=False):
        super(GCN, self).__init__()
        self.dropout = dropout
        self.activation = activation
        self.eye = eye
        self.with_features = with_features
        # Layer widths [nfeat, nhid, ..., nhid] as plain Python ints.
        width = [int(nfeat)] + [int(nhid)] * nlayers
        self.hidden = nn.ModuleList(
            GraphConvolution(width[k], width[k + 1]) for k in range(nlayers)
        )
        # The output layer consumes the last hidden output. width[-1] equals
        # nhid whenever nlayers >= 1. With nlayers == 0 it is nfeat, which
        # previously caused a shape mismatch.
        self.out = GraphConvolution(width[-1], nclass)

    def forward(self, x, adj):
        features = x
        for layer in self.hidden:
            x = self.activation(layer(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
            features = torch.cat([x, features], 1)
        output = self.out(x, adj, eye=self.eye)
        return F.log_softmax(output, dim=1), features

    def reset_parameters(self):
        """Re-initialise every hidden layer and the output layer."""
        for layer in self.hidden:
            layer.reset_parameters()
        self.out.reset_parameters()
