import torch
import os.path as osp
import aug.augmentors as A

from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning

from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from aug.eval import get_split, SVMEvaluator
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

import ot
from ot.gromov import (
    semirelaxed_gromov_wasserstein,
    semirelaxed_fused_gromov_wasserstein,
    semirelaxed_fused_gromov_wasserstein2,
    gromov_wasserstein,
    fused_gromov_wasserstein,
)
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj

# from torchmetrics.functional import pairwise_cosine_similarity
from geomloss import SamplesLoss  # See also ImagesLoss, VolumesLoss
import numpy as np
from ot.gromov._utils import (
    init_matrix,
    gwloss,
    gwggrad,
    init_matrix_semirelaxed,
    tensor_product,
)
from ot.backend import get_backend
import torch.nn.functional as F

# Compute device for models and batches: the first CUDA GPU when one is
# available, otherwise the CPU, so the script still starts without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class GConv(nn.Module):
    """Stack of GCN layers producing node embeddings and a multi-scale graph readout.

    Every GCN layer is followed by one shared PReLU module. The graph-level
    output concatenates a sum-pooled readout of every layer's node embeddings,
    so its width is ``hidden_dim * num_layers``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        # One PReLU instance (one set of learnable slopes) reused after every layer.
        self.activation = nn.PReLU(hidden_dim)
        for depth in range(num_layers):
            in_channels = input_dim if depth == 0 else hidden_dim
            self.layers.append(GCNConv(in_channels, hidden_dim))

    def forward(self, x, edge_index, batch):
        """Encode a (batched) graph.

        Returns:
            ``(z, g)`` where ``z`` is the last layer's activated node embeddings
            and ``g`` is the concatenation, along dim 1, of each layer's
            ``global_add_pool`` readout over ``batch``.
        """
        h = x
        readouts = []
        for conv in self.layers:
            h = self.activation(conv(h, edge_index))
            readouts.append(global_add_pool(h, batch))
        return h, torch.cat(readouts, dim=1)


class FC(nn.Module):
    """Projection head: a three-stage Linear+ReLU MLP plus a parallel linear skip.

    The two branches are summed. Because the MLP branch ends in a ReLU, only
    the skip branch can contribute negative values to the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = [input_dim, output_dim, output_dim, output_dim]
        stages = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.fc = nn.Sequential(*stages)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``MLP(x) + Linear(x)``."""
        skip = self.linear(x)
        return self.fc(x) + skip


class Encoder(torch.nn.Module):
    """Two-view graph encoder that also builds optimal-transport quantities.

    Each forward pass augments the input twice (``aug1`` / ``aug2``) and encodes
    each view with its own GCN stack (``gcn1`` / ``gcn2``). It projects node
    embeddings with ``mlp1`` and graph embeddings with ``mlp2``, then builds two
    cost matrices:

    * ``Mp``: Euclidean distances between the two views' augmented node features.
    * ``Mb``: Euclidean distances between the two views' projected node embeddings.

    In ``"train"`` mode it additionally:

    * solves a semi-relaxed fused Gromov-Wasserstein problem on ``Mp`` and the
      views' dense adjacency matrices (plan ``P``);
    * blends ``Mp`` with a GW-derived cost weighted by ``1 - sigma``;
    * solves exact OT on ``Mb`` (plan ``B``).
    """

    def __init__(self, gcn1, gcn2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.gcn1 = gcn1
        self.gcn2 = gcn2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2

    def forward(self, x, edge_index, batch, mode, sigma=1):
        """Encode two augmented views and (in train mode) compute transport plans.

        Args:
            x: node feature matrix of the (batched) input graph.
            edge_index: edge indices of the input graph.
            batch: graph-membership vector forwarded to the GCN readouts.
            mode: ``"train"`` computes ``P`` and ``B``; any other value skips them.
            sigma: weight of the feature cost ``Mp``. ``1 - sigma`` is passed as
                ``alpha`` to the FGW solver and weights the GW-derived cost.

        Returns:
            ``(a1, a2, g1, g2, P, B, Mp, Mb)``. ``P`` and ``B`` are ``0`` outside
            train mode.
        """
        x1, edge_index1, _, _ = self.aug1(x, edge_index)
        x2, edge_index2, _, _ = self.aug2(x, edge_index)

        z1, g1 = self.gcn1(x1, edge_index1, batch)
        z2, g2 = self.gcn2(x2, edge_index2, batch)
        a1, a2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]

        Mp = ot.dist(x1, x2, metric="euclidean")
        Mb = ot.dist(a1, a2, metric="euclidean")

        if mode != "train":
            return a1, a2, g1, g2, 0, 0, Mp, Mb

        n1 = x1.shape[0]
        n2 = x2.shape[0]
        # to_dense_adj without a batch vector returns shape [1, N, N]. Index the
        # leading dim rather than squeeze() so a single-node graph stays 2-D.
        C1 = to_dense_adj(edge_index1, max_num_nodes=n1)[0]
        C2 = to_dense_adj(edge_index2, max_num_nodes=n2)[0]
        h1 = ot.unif(n1, type_as=x1)
        h2 = ot.unif(n2, type_as=x2)

        P, _ = semirelaxed_fused_gromov_wasserstein(
            Mp, C1, C2, h1, symmetric=True, alpha=1 - sigma, log=True, G0=None
        )

        nx = get_backend(h1, C1, C2)
        constC, hC1, hC2, fC2t = init_matrix_semirelaxed(
            C1, C2, h1, loss_fun="square_loss", nx=nx
        )
        # Uniform coupling created on the inputs' device/dtype (not a global
        # device) so it always matches C1/C2/h1.
        OM = nx.ones((n1, n2), type_as=h1) / (n1 * n2)
        qOneM = nx.sum(OM, 0)
        ones_p = nx.ones(h1.shape[0], type_as=h1)
        marginal_product = nx.outer(ones_p, nx.dot(qOneM, fC2t))
        Mp2 = tensor_product(constC + marginal_product, hC1, hC2, P, nx=nx)
        Mp2 = F.normalize(Mp2)
        Mp = sigma * Mp + (1 - sigma) * Mp2

        B = ot.emd(h1, h2, Mb)
        # NOTE(review): P and B appear to come back from the POT solvers without
        # autograd history (Mp/C1/C2 are built only from inputs, and ot.emd is
        # not differentiable w.r.t. M). If so, flagging them requires_grad only
        # makes them leaves, and ||P - B|| sends no gradient to the encoder.
        # Confirm whether ot.emd2 or a differentiable solver was intended.
        P.requires_grad = True
        B.requires_grad = True

        return a1, a2, g1, g2, P, B, Mp, Mb


def train(encoder_model, dataloader, optimizer, sigma=1, rho=1):
    """Run one training epoch and return the summed batch loss.

    Each batch is encoded in ``"train"`` mode. The per-batch loss is
    ``rho * KLDivLoss(Mp, Mb) + ||P - B||_F``.

    Args:
        encoder_model: the ``Encoder`` being trained.
        dataloader: iterable of PyG batches.
        optimizer: optimizer over ``encoder_model``'s parameters.
        sigma: forwarded to the encoder (weight of the feature cost).
        rho: weight of the KL term.

    Returns:
        Sum of ``loss.item()`` over all batches of the epoch.
    """
    encoder_model.train()
    # KLDivLoss has no parameters, so one instance serves the whole epoch.
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    epoch_loss = 0
    for i, data in enumerate(dataloader):
        print("i=", i)
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            # Batch has no node features: use a constant 1-dim feature per node.
            num_nodes = data.batch.size(0)
            data.x = torch.ones(
                (num_nodes, 1), dtype=torch.float32, device=data.batch.device
            )

        _, _, _, _, P, B, Mp, Mb = encoder_model(
            data.x.float(), data.edge_index, data.batch, mode="train", sigma=sigma
        )

        # NOTE(review): nn.KLDivLoss expects log-probabilities as `input` and
        # probabilities as `target`; Mp and Mb are raw distance matrices here.
        loss = rho * kl_loss(Mp, Mb) + torch.linalg.matrix_norm(P - B, ord="fro")

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss


@ignore_warnings(category=ConvergenceWarning)
def test(encoder_model, dataloader):
    """Embed every graph in ``dataloader`` and score the embeddings with a linear SVM.

    The graph embedding is the sum of the two views' projected readouts
    (``g1 + g2``). The data is split 80/10 (train/test ratios passed to
    ``get_split``).

    Returns:
        Whatever ``SVMEvaluator`` returns; ``main`` reads its ``"acc"`` entry.
    """
    encoder_model.eval()
    embeddings = []
    labels = []

    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            if data.x is None:
                # Batch has no node features: use a constant 1-dim feature per node.
                num_nodes = data.batch.size(0)
                data.x = torch.ones(
                    (num_nodes, 1), dtype=torch.float32, device=data.batch.device
                )

            _, _, g1, g2, _, _, _, _ = encoder_model(
                data.x.float(), data.edge_index, data.batch, mode="test"
            )

            embeddings.append(g1 + g2)
            # NOTE: for one-hot / multi-column labels, torch.argmax(data.y, dim=1)
            # may be needed here instead.
            labels.append(data.y)

    x = torch.cat(embeddings, dim=0)
    y = torch.cat(labels, dim=0)

    split = get_split(num_samples=x.size(0), train_ratio=0.8, test_ratio=0.1)
    return SVMEvaluator(linear=True)(x, y, split)


def main():
    """Train the OT-regularised encoder on PROTEINS and periodically evaluate it.

    Every ``eval_every`` epochs the embeddings are scored with a linear SVM. The
    best accuracy seen so far is printed.
    """
    epochs = 1000
    eval_every = 2
    hidden_dim = 512
    num_layers = 2
    sigma = 1
    rho = 1

    dataset = TUDataset("datasets", name="PROTEINS")

    dataloader = DataLoader(dataset, batch_size=32)

    # At least 1: feature-less batches get a constant 1-dim feature in train/test.
    input_dim = max(dataset.num_features, 1)

    aug1 = A.Identity()
    aug2 = A.Compose(
        [
            A.EdgePerturbation(pe=0.2),
            A.RWSampling(use=False, num_seeds=1000, walk_length=1000),
            A.FeatureMasking(pf=0.3),
            A.NodeDropping(pn=0.0),
        ]
    )

    gcn1 = GConv(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers).to(device)
    gcn2 = GConv(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers).to(device)
    mlp1 = FC(input_dim=hidden_dim, output_dim=hidden_dim)
    # GConv concatenates one pooled readout per layer.
    mlp2 = FC(input_dim=hidden_dim * num_layers, output_dim=hidden_dim)

    encoder_model = Encoder(
        gcn1=gcn1, gcn2=gcn2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2
    ).to(device)

    optimizer = Adam(encoder_model.parameters(), lr=0.01)

    res = []
    with tqdm(total=epochs, desc="(T)") as pbar:
        for epoch in range(1, epochs + 1):
            loss = train(encoder_model, dataloader, optimizer, sigma=sigma, rho=rho)
            pbar.set_postfix({"loss": loss})
            pbar.update()

            if epoch % eval_every == 0:
                test_result = test(encoder_model, dataloader)
                res.append(test_result["acc"])
                # The message reports the best accuracy so far, not only the latest.
                print(f'Best test ACC={max(res):.4f}')


# Script entry point: training/evaluation starts only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
