import torch
from torch import nn, autograd
from constraint_module import *

class Projection(nn.Module):
    """Iteratively project a batch of configurations onto a constraint surface.

    Each iteration evaluates the constraint ``c(x)`` (a ``Constraint`` module
    from ``constraint_module``) and applies

        x <- x - c(x) / (||grad_x c(x)||^2 + eps) * grad_x c(x)

    i.e. the step along the gradient that drives the linearized constraint to
    zero. When autograd is enabled the steps are built with
    ``create_graph=True``, so gradients flow through the projection.

    Args:
        num_particles: Particles per sample; forwarded to ``Constraint``.
        dimension: Coordinates per particle; forwarded to ``Constraint``.
        iter: Number of projection steps performed by ``forward``.
        eps: Added to the squared gradient norm to avoid division by zero.
    """

    # Class-level fallback so instances pickled before ``eps`` existed still run.
    eps = 1e-7

    def __init__(self, num_particles, dimension, iter=3, eps=1e-7):
        super().__init__()
        self.iter = iter
        self.num_particles = num_particles
        self.dimension = dimension
        self.eps = eps
        # The attribute name (spelling included) is part of the state_dict keys;
        # keep it so existing checkpoints still load.
        self.constrains = Constraint(num_particles, dimension)

    def forward(self, x):
        """Return ``x`` after ``self.iter`` projection steps.

        Args:
            x: Tensor with a leading batch dimension and at least 2 dims;
                originally documented as ``(B, num_particles, dimension)``.
                All non-batch dims form one coordinate vector per sample.

        Returns:
            Tensor with the same shape as ``x``. Under ``torch.no_grad()`` the
            result is detached; otherwise it stays connected to ``x`` and to any
            parameters of the constraint module.
        """
        grad_enabled = torch.is_grad_enabled()
        upd_x = x
        # The step needs dc/dx even when the caller disabled autograd (e.g.
        # evaluation under torch.no_grad()), where autograd.grad would otherwise
        # fail because the constraint output carries no graph.
        # Background on higher-order grads:
        # https://sweetice.github.io/2019/01/26/Compute-the-gradients-of-gradients-with-pytorch/
        with torch.enable_grad():
            for _ in range(self.iter):
                if not grad_enabled or not upd_x.requires_grad:
                    # Use a detached alias rather than requires_grad_() on the
                    # caller's tensor, which would mutate the caller's input.
                    upd_x = upd_x.detach().requires_grad_(True)
                # Assumed (B, 1): one constraint value per sample (per the
                # original shape comment) — TODO confirm against Constraint.
                cons = self.constrains(upd_x)
                # ones as grad_outputs yields d(sum_b c_b)/dx, which equals the
                # per-sample gradient provided each sample's constraint depends
                # only on that sample's coordinates.
                grad = autograd.grad(
                    outputs=cons,
                    inputs=upd_x,
                    grad_outputs=torch.ones_like(cons),
                    create_graph=grad_enabled,
                    retain_graph=grad_enabled,
                )[0]
                sq_norm = (grad * grad).flatten(1).sum(1)  # (B,)
                scale = cons.reshape(-1) / (sq_norm + self.eps)  # (B,)
                # Broadcast the per-sample scale over every non-batch dimension.
                scale = scale.view((-1,) + (1,) * (grad.dim() - 1))
                upd_x = upd_x - scale * grad
        return upd_x if grad_enabled else upd_x.detach()



class Projection_2(nn.Module):
    """Iterative constraint projection using ``Constraint_Larger``.

    Each iteration evaluates the constraint ``c(x)`` (a ``Constraint_Larger``
    module from ``constraint_module``) and applies

        x <- x - c(x) / (||grad_x c(x)||^2 + eps) * grad_x c(x)

    i.e. the step along the gradient that drives the linearized constraint to
    zero. When autograd is enabled the steps are built with
    ``create_graph=True``, so gradients flow through the projection.

    Args:
        num_particles: Particles per sample; forwarded to ``Constraint_Larger``.
        dimension: Coordinates per particle; forwarded to ``Constraint_Larger``.
        iter: Number of projection steps performed by ``forward``.
        eps: Added to the squared gradient norm to avoid division by zero.
    """

    # Class-level fallback so instances pickled before ``eps`` existed still run.
    eps = 1e-7

    def __init__(self, num_particles, dimension, iter=3, eps=1e-7):
        super().__init__()
        self.iter = iter
        self.num_particles = num_particles
        self.dimension = dimension
        self.eps = eps
        # The attribute name (spelling included) is part of the state_dict keys;
        # keep it so existing checkpoints still load.
        self.constrains = Constraint_Larger(num_particles, dimension)

    def forward(self, x):
        """Return ``x`` after ``self.iter`` projection steps.

        Args:
            x: Tensor with a leading batch dimension and at least 2 dims;
                originally documented as ``(B, num_particles, dimension)``.
                All non-batch dims form one coordinate vector per sample.

        Returns:
            Tensor with the same shape as ``x``. Under ``torch.no_grad()`` the
            result is detached; otherwise it stays connected to ``x`` and to any
            parameters of the constraint module.
        """
        grad_enabled = torch.is_grad_enabled()
        upd_x = x
        # The step needs dc/dx even when the caller disabled autograd (e.g.
        # evaluation under torch.no_grad()), where autograd.grad would otherwise
        # fail because the constraint output carries no graph.
        # Background on higher-order grads:
        # https://sweetice.github.io/2019/01/26/Compute-the-gradients-of-gradients-with-pytorch/
        with torch.enable_grad():
            for _ in range(self.iter):
                if not grad_enabled or not upd_x.requires_grad:
                    # Use a detached alias rather than requires_grad_() on the
                    # caller's tensor, which would mutate the caller's input.
                    upd_x = upd_x.detach().requires_grad_(True)
                # Assumed (B, 1): one constraint value per sample (per the
                # original shape comment) — TODO confirm against Constraint_Larger.
                cons = self.constrains(upd_x)
                # ones as grad_outputs yields d(sum_b c_b)/dx, which equals the
                # per-sample gradient provided each sample's constraint depends
                # only on that sample's coordinates.
                grad = autograd.grad(
                    outputs=cons,
                    inputs=upd_x,
                    grad_outputs=torch.ones_like(cons),
                    create_graph=grad_enabled,
                    retain_graph=grad_enabled,
                )[0]
                sq_norm = (grad * grad).flatten(1).sum(1)  # (B,)
                scale = cons.reshape(-1) / (sq_norm + self.eps)  # (B,)
                # Broadcast the per-sample scale over every non-batch dimension.
                scale = scale.view((-1,) + (1,) * (grad.dim() - 1))
                upd_x = upd_x - scale * grad
        return upd_x if grad_enabled else upd_x.detach()

