import torch
from torch import nn, autograd
from constraint_module import *

class Projection_Rope_Soft(nn.Module):
    """Iteratively project particle positions toward a constraint's zero set.

    Each of ``iter`` iterations applies one differentiable projection step:

        delta_p = -k_i * C(p) / (sum(grad_p C ** 2) + eps) * grad_p C

    where ``C`` is ``self.constrains`` (a ``Constraint`` defined outside this
    file, presumably in ``constraint_module``) and ``k_i = k ** (1 / (i + 1))``.
    Second-order graphs are kept (``create_graph=True``), so a loss on the
    output can backpropagate through the projection steps.

    Args:
        num_particles: number of particles, forwarded to ``Constraint``.
        dimension: per-particle coordinate dimension, forwarded to ``Constraint``.
        iter: number of projection iterations.
        k: stiffness factor; iteration ``i`` scales its step by ``k ** (1 / (i + 1))``.
    """

    def __init__(self, num_particles, dimension, iter = 3, k = 0.8):
        super(Projection_Rope_Soft, self).__init__()
        self.iter = iter
        self.num_particles = num_particles
        self.dimension = dimension
        self.constrains = Constraint(num_particles, dimension)
        self.k = k

    def forward(self, x):
        """Apply ``self.iter`` soft projection steps to ``x``.

        Args:
            x: particle positions, B x num_particles x dimension. ``x`` itself
               is never modified, including its ``requires_grad`` flag.

        Returns:
            Projected positions with the same shape as ``x``. The result keeps
            its autograd graph only if grad mode was enabled by the caller.

        Note:
            ``self.constrains`` must return one constraint value per batch
            element (shape ``B`` or ``B x 1``).
        """
        eps = 1e-7  # avoids division by zero when the constraint gradient vanishes
        grad_enabled = torch.is_grad_enabled()
        upd_x = x
        # autograd.grad needs a graph even if the caller is inside
        # torch.no_grad() (e.g. evaluation), so force grad mode locally.
        with torch.enable_grad():
            for i in range(self.iter):
                if not upd_x.requires_grad:
                    # Differentiate w.r.t. a detached alias instead of calling
                    # requires_grad_() on the caller's tensor.
                    upd_x = upd_x.detach().requires_grad_(True)
                # Same value as the former ``1 - (1 - k**(1/(i+1)))``.
                # NOTE(review): PBD's usual stiffness correction is
                # 1 - (1 - k)**(1 / n_iters); confirm this schedule is intended.
                k_iter = self.k ** (1.0 / (i + 1))
                cons = self.constrains(upd_x)
                # Gradients of gradients with create_graph=True, see
                # https://sweetice.github.io/2019/01/26/Compute-the-gradients-of-gradients-with-pytorch/
                grad = autograd.grad(
                    outputs=cons,
                    inputs=upd_x,
                    grad_outputs=torch.ones_like(cons),
                    create_graph=True,
                    retain_graph=True,
                )[0]
                # cons: one value per batch item; grad: B x N x D.
                denom = (grad * grad).sum([1, 2]) + eps
                # Per-batch step size, broadcast over particles and coordinates.
                scale = (cons.reshape(-1) / denom).view(-1, 1, 1)
                # delta_p = - ( c / sum(grad_p^2) ) * grad_p
                upd_x = upd_x - scale * grad * k_iter
        # Honour the caller's no_grad request by dropping the graph built above.
        return upd_x if grad_enabled else upd_x.detach()




class Common_NN(nn.Module):
    """Unconstrained baseline: run positions through ``Net`` and restore layout.

    ``Net`` is defined outside this file (presumably in ``constraint_module``).
    No projection is applied, and ``iter`` is stored but not read by ``forward``.

    Args:
        num_particles: number of particles, forwarded to ``Net``.
        dimension: per-particle coordinate dimension, forwarded to ``Net``.
        iter: kept for signature parity with the projection model.
    """

    def __init__(self, num_particles, dimension, iter = 3):
        super().__init__()
        self.num_particles = num_particles
        self.dimension = dimension
        self.iter = iter
        self.net = Net(num_particles, dimension)

    def forward(self, x):
        """Evaluate ``self.net`` on ``x`` and view the result with ``x``'s layout.

        Args:
            x: particle positions, B x num_particles x dimension.

        Returns:
            ``self.net(x)`` viewed as ``[x.size(0), x.size(1), x.size(2)]``.
        """
        prediction = self.net(x)
        layout = [x.size(0), x.size(1), x.size(2)]
        return prediction.view(layout)