import torch
import os.path as osp
import aug.augmentors as A
import torch_geometric.transforms as T
import torch.nn.functional as F

from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from aug.eval import get_split, LREvaluator
from torch_geometric.nn import GCNConv
from torch_geometric.nn.inits import uniform
from torch_geometric.datasets import Planetoid

import copy
import gc

import ot
from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein2, gromov_wasserstein, fused_gromov_wasserstein
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj
# from torchmetrics.functional import pairwise_cosine_similarity
from geomloss import SamplesLoss  # See also ImagesLoss, VolumesLoss
import numpy as np


# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the module can still be imported and run on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def f(G):
    """Quadratic regulariser: half the sum of squared entries of ``G``.

    Equivalent to 0.5 * ||G||_F^2 for a tensor ``G``; returns a 0-d tensor.
    """
    squared = G ** 2
    return torch.sum(squared) * 0.5


def df(G):
    """Gradient of ``f``: d/dG [0.5 * ||G||_F^2] is ``G`` itself.

    The input object is returned unchanged (no copy is made).
    """
    gradient = G
    return gradient


class GConv(nn.Module):
    """A stack of ``num_layers`` GCNConv layers with a single shared PReLU.

    The first layer maps ``input_dim`` -> ``hidden_dim``; every later layer
    maps ``hidden_dim`` -> ``hidden_dim``. The same PReLU module (one slope per
    hidden channel) is applied after every convolution.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        # Register the (initially empty) layer list before the activation so
        # parameter order matches: conv layers first, then the PReLU.
        self.layers = torch.nn.ModuleList()
        self.activation = nn.PReLU(hidden_dim)
        for layer_idx in range(num_layers):
            in_channels = input_dim if layer_idx == 0 else hidden_dim
            self.layers.append(GCNConv(in_channels, hidden_dim))

    def forward(self, x, edge_index, edge_weight=None):
        """Apply each convolution followed by the shared activation.

        With zero layers the input ``x`` is returned as-is.
        """
        h = x
        for layer in self.layers:
            h = self.activation(layer(h, edge_index, edge_weight))
        return h


class Encoder(torch.nn.Module):
    """Two-view graph encoder that also returns optimal-transport (OT) terms.

    ``forward`` augments the input graph twice (``augmentor`` is a pair of
    callables ``(aug1, aug2)``), encodes each view with its own GNN, and
    computes:

    * ``Mp`` - pairwise Euclidean distances between the raw node features of
      the two views;
    * ``Mb`` - pairwise Euclidean distances between the two views' embeddings;
    * in ``mode == 'train'`` only, exact OT plans ``P`` (on ``Mp``) and ``B``
      (on ``Mb``) with uniform marginals, plus the ``ot.emd`` logs.

    NOTE(review): ``Mp`` and ``Mb`` are dense (N1 x N2) matrices and are built
    in every mode, so memory grows quadratically with the number of nodes.
    """

    def __init__(self, encoder1, encoder2, augmentor, hidden_dim):
        super(Encoder, self).__init__()
        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.augmentor = augmentor
        self.project = torch.nn.Linear(hidden_dim, hidden_dim)
        uniform(hidden_dim, self.project.weight)
        # NOTE(review): not used by forward(). Kept so the parameter list and
        # state_dict keys stay unchanged for existing checkpoints/optimizers.
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim + 512, hidden_dim) for _ in range(1)])

    @staticmethod
    def corruption(x, edge_index, edge_weight):
        """Randomly permute the rows of ``x``; graph structure is unchanged."""
        return x[torch.randperm(x.size(0))], edge_index, edge_weight

    def forward(self, x, edge_index, mode, edge_weight=None):
        """Encode two augmented views and compute the OT alignment terms.

        Args:
            x: node feature matrix (presumably one row per node - confirm).
            edge_index: graph connectivity, passed to augmentors and encoders.
            mode: ``'train'`` computes the OT plans; any other value skips
                them (used for evaluation).
            edge_weight: optional edge weights.

        Returns:
            ``(z1, z2, g1, g2, z1n, z2n, P, B, logP, logB, Mp, Mb)``. Outside
            train mode ``P``, ``B``, ``logP`` and ``logB`` are the integer 0.
        """
        aug1, aug2 = self.augmentor

        x1, edge_index1, edge_weight1, _ = aug1(x, edge_index, edge_weight)
        x2, edge_index2, edge_weight2, _ = aug2(x, edge_index, edge_weight)

        # Uniform marginals over the rows of each view (sizes may differ).
        h1 = ot.unif(x1.shape[0], type_as=x1)
        h2 = ot.unif(x2.shape[0], type_as=x2)

        is_train = mode == 'train'

        # Feature-space cost matrix between the two views.
        Mp = ot.dist(x1, x2, metric='euclidean')
        if is_train:
            P, logP = ot.emd(h1, h2, Mp, log=True)
        else:
            P = 0
            logP = 0

        z1 = self.encoder1(x1, edge_index1, edge_weight1)
        z2 = self.encoder2(x2, edge_index2, edge_weight2)

        # Summary vectors: mean over dim 0, sigmoid, then a linear projection.
        g1 = self.project(torch.sigmoid(z1.mean(dim=0, keepdim=True)))
        g2 = self.project(torch.sigmoid(z2.mean(dim=0, keepdim=True)))

        # Embeddings of each view after shuffling its feature rows.
        z1n = self.encoder1(*self.corruption(x1, edge_index1, edge_weight1))
        z2n = self.encoder2(*self.corruption(x2, edge_index2, edge_weight2))

        # Embedding-space cost matrix; built from z1/z2, so (with POT's torch
        # backend) it should be differentiable w.r.t. the encoders.
        Mb = ot.dist(z1, z2, metric='euclidean')
        if is_train:
            B, logB = ot.emd(h1, h2, Mb, log=True)
        else:
            B = 0
            logB = 0

        return z1, z2, g1, g2, z1n, z2n, P, B, logP, logB, Mp, Mb


def train(encoder_model, data, optimizer, rho=1):
    """Run a single optimisation step of the OT-alignment objective.

    The loss is ``rho * KLDiv(Mp, Mb) + ||P - B||_F``, where ``P``, ``B``,
    ``Mp`` and ``Mb`` are taken from ``encoder_model``'s 12-tuple output.

    Args:
        encoder_model: module called as
            ``encoder_model(x, edge_index, mode='train')`` that returns
            ``(z1, z2, g1, g2, z1n, z2n, P, B, logP, logB, Mp, Mb)``.
        data: graph object exposing ``x`` and ``edge_index``.
        optimizer: optimiser over ``encoder_model``'s parameters.
        rho: weight of the KL term.

    Returns:
        The loss value (Python float) computed before the optimiser step.
    """
    encoder_model.train()
    optimizer.zero_grad()

    (_, _, _, _, _, _, P, B, _, _, Mp, Mb) = encoder_model(
        data.x, data.edge_index, mode='train')

    # NOTE(review): nn.KLDivLoss expects `input` as log-probabilities and
    # `target` as probabilities; here raw distance matrices are passed
    # (input=Mp, target=Mb). Presumably gradients reach the encoders through
    # `target` (Mb) - confirm this is the intended objective.
    kl_loss = nn.KLDivLoss(reduction='batchmean')
    # NOTE(review): plans returned by `ot.emd` are likely constants w.r.t.
    # autograd, in which case the Frobenius term adds value but no gradient.
    loss = rho * kl_loss(Mp, Mb) + torch.linalg.matrix_norm(P - B, ord='fro')

    loss.backward()
    optimizer.step()

    return loss.item()


def test(encoder_model, data):
    """Evaluate the summed view embeddings with a logistic-regression probe.

    The encoder runs in eval mode with ``mode='test'``, so no OT plans are
    computed. The embeddings ``z1 + z2`` are scored by ``LREvaluator`` on a
    random 10% / 80% train/test split (via ``get_split``).

    Returns:
        Whatever ``LREvaluator`` returns; ``main`` reads its ``"acc"`` key.
    """
    encoder_model.eval()
    # Inference only: no_grad avoids building an autograd graph through both
    # encoders and the dense distance matrices computed in forward().
    # NOTE(review): assumes LREvaluator does not need gradients w.r.t. z.
    with torch.no_grad():
        z1, z2, *_ = encoder_model(data.x, data.edge_index, mode='test')
    z = z1 + z2
    split = get_split(num_samples=z.size(0), train_ratio=0.1, test_ratio=0.8)
    return LREvaluator()(z, data.y, split)


def main(epochs=1000, eval_every=1):
    """Train the two-view OT encoder on PubMed and report probe accuracy.

    Args:
        epochs: number of training epochs (the original script used 1000).
        eval_every: run the linear-probe evaluation every ``eval_every``
            epochs (1 = every epoch, as before).

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f'eval_every must be >= 1, got {eval_every}')

    dataset = Planetoid('datasets', name='PubMed', transform=T.NormalizeFeatures())
    data = dataset[0].to(device)

    aug1 = A.Identity()
    aug2 = A.Compose([
        A.EdgePerturbation(pe=0.3),
        A.RWSampling(use=False, num_seeds=100, walk_length=10),
        A.FeatureMasking(pf=0.4),
        A.NodeDropping(pn=0.0),
    ])

    hidden_dim = 512
    gconv1 = GConv(input_dim=dataset.num_features,
                   hidden_dim=hidden_dim, num_layers=2).to(device)
    gconv2 = GConv(input_dim=dataset.num_features,
                   hidden_dim=hidden_dim, num_layers=2).to(device)

    encoder_model = Encoder(encoder1=gconv1, encoder2=gconv2,
                            augmentor=(aug1, aug2), hidden_dim=hidden_dim).to(device)

    optimizer = Adam(encoder_model.parameters(), lr=0.001)

    rho = 1  # weight of the KL term in train(); constant across epochs
    best_acc = 0.0
    with tqdm(total=epochs, desc='(T)') as pbar:
        for epoch in range(1, epochs + 1):
            loss = train(encoder_model, data, optimizer, rho=rho)

            pbar.set_postfix({'loss': loss})
            pbar.update()

            if epoch % eval_every == 0:
                test_result = test(encoder_model, data)
                acc = test_result["acc"]
                # Previously this printed the current accuracy labelled as
                # "Best"; track the real best and report both.
                best_acc = max(best_acc, acc)
                # pbar.write keeps the progress bar intact (print garbles it).
                pbar.write(f'Test ACC={acc:.4f}, Best test ACC={best_acc:.4f}')


if __name__ == '__main__':
    main()
