import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Dropout, Linear as Lin, init

from torch import nn
from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T
import torch_geometric.data.batch as dataBatch
from torch_geometric.data import DataLoader
from torch_geometric.nn import DynamicEdgeConv, global_max_pool
from pointnet2_classification import MLP
from torch_geometric.utils import sort_edge_index, to_dense_batch
from DiffGCN import DiffGCNBlock
from mgpool import mgunpool
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
import sys
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from torch.autograd import Variable

# Directory containing this script; gradient-flow plots are written next to it.
currPath = osp.dirname(osp.realpath(__file__))
expname = 'classification.pth'
print("Exp name:", expname)

# Dataset root and checkpoint file. Passing "slurm" on the command line
# selects the alternate set of locations.
if "slurm" in sys.argv:
    path = '/home/eliasof/meshfit/pytorch_geometric/data/'
    savepath = '/home/eliasof/meshfit/pytorch_geometric/checkpoints/new/' + expname
else:
    path = '/home/cluster/users/erant_group/ModelNet40'
    savepath = '/home/cluster/users/erant_group/diffops/new/' + expname

# Augmentation pipeline: random scale, a random rotation of up to 15 degrees
# about each axis, then point sampling.
# NOTE(review): nothing in this section passes `train_transform` to a dataset;
# both splits below use the plain `transform`. Confirm whether that is intended.
train_transform = T.Compose([
    T.RandomScale((0.9, 1.1)),
    *(T.RandomRotate(15, axis=axis) for axis in (0, 1, 2)),
    T.SamplePoints(1024),
])

pre_transform = T.NormalizeScale()
transform = T.SamplePoints(1024)
batchsize = 20

train_dataset = ModelNet(
    path, '10', train=True, transform=transform, pre_transform=pre_transform)
test_dataset = ModelNet(
    path, '10', train=False, transform=transform, pre_transform=pre_transform)

# The training loader drops the final incomplete batch; the test loader keeps
# every sample and preserves order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=6,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=6,
)



def plot_grad_flow(named_parameters, batch, epoch, filename):
    """Plot per-layer gradient magnitudes and save the figure as a JPEG.

    Can be used for checking for possible gradient vanishing / exploding
    problems. Call it after ``loss.backward()``, e.g.
    ``plot_grad_flow(model.named_parameters(), batch, epoch, 'grads_')``.

    Bars are drawn onto the current matplotlib figure with a low alpha, so
    repeated calls overlay on the same axes; the figure is not cleared here.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs.
        batch: batch identifier, used only in the output file name.
        epoch: epoch identifier, used only in the output file name.
        filename: prefix of the output file name. The file is written to
            ``currPath`` as ``<filename><batch>_epoch_<epoch>.jpg``.
    """
    layers = []
    ave_grads = []
    max_grads = []
    for name, param in named_parameters:
        if not param.requires_grad or "bias" in name:
            continue
        if param.grad is None:
            # Trainable parameters that did not take part in the last backward
            # pass have no gradient; skip them instead of crashing.
            continue
        grad_abs = param.grad.detach().abs()
        layers.append(name)
        # .item() yields plain Python floats, so matplotlib never receives
        # device tensors (which it cannot convert when they live on a GPU).
        ave_grads.append(grad_abs.mean().item())
        max_grads.append(grad_abs.max().item())

    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    plt.savefig(osp.join(currPath, filename + str(batch) + '_epoch_' + str(epoch) + '.jpg'))


def stnknn(x, k):
    """Return, for every point, the indices of its k nearest neighbours.

    ``x`` is indexed as (batch_size, num_dims, num_points): squared norms are
    summed over dim 1 and the Gram matrix is taken over the point axis.
    Neighbours are ranked by negative squared Euclidean distance, so the
    closest points come first.

    Returns:
        Index tensor of shape (batch_size, num_points, k).
    """
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)
    cross_term = -2 * torch.matmul(x.transpose(2, 1), x)
    # -(|a|^2 - 2 a.b + |b|^2): the largest values are the nearest points.
    neg_sq_dist = -sq_norms - cross_term - sq_norms.transpose(2, 1)
    return torch.topk(neg_sq_dist, k=k, dim=-1).indices


def get_graph_feature(x, k=20, normals=None, idx=None, dim9=False):
    """Build per-edge features from each point and its k neighbours.

    For every point the output stacks, along the channel axis, the point's own
    features, the difference ``neighbour - point`` and (optionally) the
    point's normals.

    Args:
        x: tensor viewed as (batch_size, num_dims, num_points).
        k: number of neighbours per point; must match the last dimension of
            ``idx`` when ``idx`` is given.
        normals: optional tensor that is viewed as
            (batch_size, num_points, 1, num_dims) and appended as extra
            channels.
        idx: optional precomputed neighbour indices of shape
            (batch_size, num_points, k). When omitted they are computed with
            ``stnknn``.
        dim9: when true, neighbours are searched using only channels 6 and up.

    Returns:
        Tensor of shape (batch_size, C, num_points, k) with C = 2 * num_dims,
        or 3 * num_dims when ``normals`` is given.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        knn_input = x[:, 6:] if dim9 else x
        idx = stnknn(knn_input, k=k)  # (batch_size, num_points, k)

    # Offsets that turn per-sample point indices into indices of the flattened
    # (batch_size * num_points) point list. They are created on the device of
    # the indices instead of a hard-coded 'cuda', so CPU runs work as well.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    if normals is None:
        parts = (x, feature - x)
    else:
        normals = normals.contiguous()
        normals = normals.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
        parts = (x, feature - x, normals)

    return torch.cat(parts, dim=3).permute(0, 3, 1, 2)


class Transform_Net(nn.Module):
    """Predict one square transform matrix per sample from edge features.

    With ``open=True`` the network takes 6 input channels and emits 3x3
    matrices; otherwise it takes 2 * 64 channels and emits 64x64 matrices.
    The last linear layer starts with zero weights and an identity bias, so a
    freshly constructed network outputs the identity transform.
    """

    def __init__(self, args=None, normals=False, open=True):
        super().__init__()
        self.args = args
        self.k = 3
        if open:
            # The input width is 6 whether or not `normals` is set.
            self.initialFeatSize = 6
            self.outputSize = 3
        else:
            self.initialFeatSize = 2 * 64
            self.outputSize = 64

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        # This 1024-wide norm is captured by `conv3` below. The `bn3`
        # attribute is rebound to a 512-wide norm further down, which is the
        # one `forward` calls directly. The assignment order is kept so the
        # registered module names and their order stay the same.
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = nn.Sequential(
            nn.Conv2d(self.initialFeatSize, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn3 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        side = self.outputSize
        self.transform = nn.Linear(256, side * side)
        # Start from the identity transform: zero weights, identity bias.
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(side, side))

    def forward(self, x):
        """Map (batch_size, C, num_points, k) edge features to transforms.

        Returns a tensor of shape (batch_size, outputSize, outputSize).
        """
        n_batch = x.size(0)

        # (batch_size, C, num_points, k) -> (batch_size, 128, num_points, k)
        edge_feat = self.conv2(self.conv1(x))
        # Max over the neighbour axis -> (batch_size, 128, num_points)
        point_feat = edge_feat.max(dim=-1).values

        # (batch_size, 1024, num_points), then max over points -> (batch_size, 1024)
        global_feat = self.conv3(point_feat).max(dim=-1).values

        hidden = F.leaky_relu(self.bn3(self.linear1(global_feat)), negative_slope=0.2)  # -> 512
        hidden = F.leaky_relu(self.bn4(self.linear2(hidden)), negative_slope=0.2)  # -> 256

        flat = self.transform(hidden)  # (batch_size, outputSize * outputSize)
        return flat.view(n_batch, self.outputSize, self.outputSize)


class STN3d(nn.Module):
    """Spatial transformer that predicts a 3x3 matrix per point cloud.

    The network regresses an offset that is added to the identity matrix, so
    an all-zero regression output yields the identity transform.
    """

    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Map points of shape (batch_size, 3, num_points) to (batch_size, 3, 3)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max over the point axis gives one 1024-d descriptor per cloud.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity offset directly on the device of the activations
        # (rather than on the CPU followed by .cuda(), which targets the
        # default GPU) and let broadcasting add it to every sample.
        iden = torch.eye(3, dtype=torch.float32, device=x.device).view(1, 9)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    """Spatial transformer that predicts a k x k matrix per sample.

    The network regresses an offset that is added to the k x k identity, so an
    all-zero regression output yields the identity transform.

    Args:
        k: number of input channels and side length of the predicted matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """Map features of shape (batch_size, k, num_points) to (batch_size, k, k)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max over the point axis gives one 1024-d descriptor per sample.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity offset directly on the device of the activations
        # (rather than on the CPU followed by .cuda(), which targets the
        # default GPU) and let broadcasting add it to every sample.
        iden = torch.eye(self.k, dtype=torch.float32, device=x.device).view(1, self.k * self.k)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x


class Net(torch.nn.Module):
    """Point-cloud classifier: STN alignment, four DiffGCN blocks, MLP head.

    Args:
        out_channels: number of output classes.
        k: neighbourhood size handed to the DiffGCN blocks.
        aggr: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, out_channels, k=20, aggr='max'):
        super().__init__()
        self.k = k
        self.transform_net = STN3d()

        # The later blocks receive k rounded up to an integer; the first block
        # receives k as given.
        k_rounded = int(np.ceil(k / 1))
        self.conv1 = DiffGCNBlock(3, 64, k, 1)
        self.conv2 = DiffGCNBlock(64, 64, k_rounded, 2, pool=False)
        self.conv3 = DiffGCNBlock(64, 64, k_rounded, 2, pool=True)
        self.conv4 = DiffGCNBlock(64, 128, k_rounded, 2, pool=True)

        # Concatenated per-point features: three 64-d maps plus one 128-d map.
        self.lin1 = MLP([3 * 64 + 128, 1024])

        self.mlp = Seq(
            MLP([1024, 512]),
            Dropout(0.5),
            MLP([512, 256]),
            Dropout(0.5),
            Lin(256, out_channels),
        )

    def forward(self, data):
        """Classify a batch of point clouds.

        Reads ``data.pos`` and ``data.batch``.

        Returns:
            A tuple ``(log_probs, pos, t)``: per-graph log-softmax scores, the
            positions returned by the last DiffGCN block, and the transforms
            predicted by the STN.
        """
        batch = data.batch

        # Densify to (batch_size, num_points, 3) so the STN can process whole
        # clouds, apply the predicted transform, then drop the padding again.
        dense_pos, valid = to_dense_batch(data.pos, batch)
        channels_first = dense_pos.transpose(2, 1).contiguous()
        t = self.transform_net(channels_first).contiguous()
        points_first = channels_first.transpose(2, 1).contiguous()
        aligned = torch.bmm(points_first, t).contiguous()
        x0 = aligned[valid, :]
        pos = x0.contiguous()

        # Batch assignment at full resolution, needed for the final pooling.
        origbatch = batch.clone()

        x1, pos, batch = self.conv1(x0, pos, batch)
        x2, pos, batch = self.conv2(x1, pos, batch)
        x3, pos, batch, pooldata3 = self.conv3(x2, pos, batch)
        x4, pos, batch, pooldata4 = self.conv4(x3, pos, batch)

        # Do unpooling: x4 passed through two pooling steps, x3 through one.
        x4 = mgunpool(x4, *pooldata4)
        x4 = mgunpool(x4, *pooldata3)
        x3 = mgunpool(x3, *pooldata3)

        stacked = torch.cat([x1, x2, x3, x4], dim=1)
        out = self.lin1(stacked)
        out = global_max_pool(out, origbatch)
        out = self.mlp(out)
        return F.log_softmax(out, dim=1), pos, t


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(train_dataset.num_classes, k=5).to(device)

prev_test_acc = 0
if "continue" in sys.argv:
    # Resume from the saved weights with a reduced learning rate. The
    # optimizer state in the checkpoint is intentionally not restored.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000125)
    # map_location lets a checkpoint written on a GPU load on any device.
    checkpoint = torch.load(savepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = 0
    loss = checkpoint['loss']
    prev_test_acc = checkpoint['test_acc']
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Build the scheduler only after the final optimizer exists. Creating it
# earlier and then replacing the optimizer would leave the scheduler decaying
# the learning rate of an optimizer that is never stepped.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)


def train(epoch):
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. Every 10 batches the mean per-sample loss of those batches is
    printed.

    Args:
        epoch: current epoch number; not used inside the function.

    Returns:
        Mean per-sample loss over the samples actually processed. The loader
        may drop the last incomplete batch, so the count is accumulated here
        rather than taken from the dataset length.
    """
    model.train()
    total_loss = 0
    total_seen = 0
    window_loss = 0
    window_seen = 0
    optimizer.zero_grad()
    for i, data in enumerate(train_loader):
        data = data.to(device)

        out, pos, t = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Read the loss once (a single device sync) and weight it by the batch
        # size so the sums are per-sample totals.
        num_graphs = data.num_graphs
        batch_loss = loss.item() * num_graphs
        total_loss += batch_loss
        total_seen += num_graphs
        window_loss += batch_loss
        window_seen += num_graphs

        if (i + 1) % 10 == 0:
            print('[{}/{}] Loss: {:.4f}'.format(
                i + 1, len(train_loader), window_loss / window_seen), flush=True)
            window_loss = 0
            window_seen = 0

    # max(..., 1) keeps an empty loader from dividing by zero.
    return total_loss / max(total_seen, 1)


def test(loader):
    """Evaluate the module-level ``model`` on ``loader``.

    Prints per-class correct/seen counts, per-class accuracy and the mean
    accuracy over the classes that occur in ``loader``.

    The number of classes is taken from the width of the model output instead
    of a fixed constant, so classes the model cannot predict do not show up as
    0/0 entries and turn the printed average into NaN.

    Args:
        loader: data loader yielding batches with ``y`` labels.

    Returns:
        Overall accuracy: correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    total_seen_class = None
    total_correct_class = None
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            log_probs, pos, t = model(data)
        pred = log_probs.max(dim=1)[1]
        hits = pred.eq(data.y)
        correct += hits.sum().item()

        num_classes = log_probs.size(1)
        if total_seen_class is None:
            total_seen_class = torch.zeros(num_classes)
            total_correct_class = torch.zeros(num_classes)

        # Count on the CPU with bincount instead of looping over samples.
        labels = data.y.view(-1).cpu()
        hit_labels = labels[hits.view(-1).cpu()]
        total_seen_class += torch.bincount(labels, minlength=num_classes).float()
        total_correct_class += torch.bincount(hit_labels, minlength=num_classes).float()

    if total_seen_class is None:
        # Empty loader: nothing was counted.
        total_seen_class = torch.zeros(0)
        total_correct_class = torch.zeros(0)

    print("total correct per class:", total_correct_class)
    print("total seen per class:", total_seen_class)
    class_accuracy = total_correct_class / total_seen_class
    print("Per Class accuracy:", class_accuracy)
    # Average only over classes that occurred; the others are 0/0.
    avg_class_accuracy = torch.mean(class_accuracy[total_seen_class > 0], dim=0)
    print("Per class average accuracy:", avg_class_accuracy)
    return correct / len(loader.dataset)


for epoch in range(1, 1001):
    if "continue" in sys.argv and epoch == 1:
        print("Continuing, testing first:")
        # Evaluate the restored checkpoint before training, as the message
        # above announces.
        print('Restored checkpoint, Test: {:.4f}'.format(test(test_loader)))
    loss = train(epoch)
    test_acc = test(test_loader)
    print('Epoch {:03d}, Loss: {:.4f}, Test: {:.4f}'.format(
        epoch, loss, test_acc))
    scheduler.step()
    # Checkpoints are written only for fresh runs, and only when the test
    # accuracy improves on the best value so far.
    # NOTE(review): resumed runs never save, although prev_test_acc is
    # restored from the checkpoint; confirm that this is intended.
    if "continue" not in sys.argv:
        if prev_test_acc < test_acc:
            prev_test_acc = test_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'test_acc': test_acc,
            }, savepath)
