import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from . import top_pool_cuda, bottom_pool_cuda, left_pool_cuda, right_pool_cuda

class TopPoolFunction(Function):
    """Autograd Function for top pooling backed by the ``top_pool_cuda`` extension.

    ``forward`` asks the extension to fill an output tensor and an int32
    ``argmax`` tensor (both allocated with the input's shape); ``backward``
    hands that ``argmax`` back to the extension, which writes the input
    gradient. The Function is ``once_differentiable``: double backward is
    not supported.
    """

    @staticmethod
    def forward(ctx, input, keep_arg=True):
        """Pool ``input`` with ``top_pool_cuda.forward``.

        Args:
            ctx: Autograd context.
            input: Tensor to pool. NOTE(review): presumably a CUDA tensor of
                the rank the extension expects -- confirm against the kernel.
            keep_arg: Flag forwarded unchanged to the extension (the
                ``TopPool`` module passes ``self.training``). NOTE(review):
                presumably controls whether ``argmax`` is populated; if so,
                ``backward`` is only meaningful after a ``keep_arg=True`` call.

        Returns:
            A new tensor with the same shape and dtype as ``input``.
        """
        # The extension presumably indexes raw dense memory (backward already
        # passes grad_output.contiguous()); normalise the layout so input,
        # output and argmax all share the same contiguous strides. This is a
        # no-op for inputs that are already contiguous.
        input = input.contiguous()
        output = torch.zeros_like(input)
        argmax = torch.zeros_like(input, dtype=torch.int)
        top_pool_cuda.forward(input, output, argmax, keep_arg)

        # Route every tensor backward needs through save_for_backward instead
        # of a plain ctx attribute, as the autograd docs recommend.
        ctx.save_for_backward(input, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute the input gradient with ``top_pool_cuda.backward``.

        Returns:
            ``(grad_input, None)`` -- there is no gradient for ``keep_arg``.
        """
        # `saved_variables` is a deprecated alias of `saved_tensors`.
        input, argmax = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        # Allocate from the contiguous gradient so grad_input has the same
        # contiguous layout the kernel sees for grad_output.
        grad_input = torch.zeros_like(grad_output)
        top_pool_cuda.backward(grad_output, input, argmax, grad_input)
        return grad_input, None


class BottomPoolFunction(Function):
    """Autograd Function for bottom pooling backed by the ``bottom_pool_cuda`` extension.

    ``forward`` asks the extension to fill an output tensor and an int32
    ``argmax`` tensor (both allocated with the input's shape); ``backward``
    hands that ``argmax`` back to the extension, which writes the input
    gradient. The Function is ``once_differentiable``: double backward is
    not supported.
    """

    @staticmethod
    def forward(ctx, input, keep_arg=True):
        """Pool ``input`` with ``bottom_pool_cuda.forward``.

        Args:
            ctx: Autograd context.
            input: Tensor to pool. NOTE(review): presumably a CUDA tensor of
                the rank the extension expects -- confirm against the kernel.
            keep_arg: Flag forwarded unchanged to the extension (the
                ``BottomPool`` module passes ``self.training``). NOTE(review):
                presumably controls whether ``argmax`` is populated; if so,
                ``backward`` is only meaningful after a ``keep_arg=True`` call.

        Returns:
            A new tensor with the same shape and dtype as ``input``.
        """
        # The extension presumably indexes raw dense memory (backward already
        # passes grad_output.contiguous()); normalise the layout so input,
        # output and argmax all share the same contiguous strides. This is a
        # no-op for inputs that are already contiguous.
        input = input.contiguous()
        output = torch.zeros_like(input)
        argmax = torch.zeros_like(input, dtype=torch.int)
        bottom_pool_cuda.forward(input, output, argmax, keep_arg)

        # Route every tensor backward needs through save_for_backward instead
        # of a plain ctx attribute, as the autograd docs recommend.
        ctx.save_for_backward(input, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute the input gradient with ``bottom_pool_cuda.backward``.

        Returns:
            ``(grad_input, None)`` -- there is no gradient for ``keep_arg``.
        """
        # `saved_variables` is a deprecated alias of `saved_tensors`.
        input, argmax = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        # Allocate from the contiguous gradient so grad_input has the same
        # contiguous layout the kernel sees for grad_output.
        grad_input = torch.zeros_like(grad_output)
        bottom_pool_cuda.backward(grad_output, input, argmax, grad_input)
        return grad_input, None


class LeftPoolFunction(Function):
    """Autograd Function for left pooling backed by the ``left_pool_cuda`` extension.

    ``forward`` asks the extension to fill an output tensor and an int32
    ``argmax`` tensor (both allocated with the input's shape); ``backward``
    hands that ``argmax`` back to the extension, which writes the input
    gradient. The Function is ``once_differentiable``: double backward is
    not supported.
    """

    @staticmethod
    def forward(ctx, input, keep_arg=True):
        """Pool ``input`` with ``left_pool_cuda.forward``.

        Args:
            ctx: Autograd context.
            input: Tensor to pool. NOTE(review): presumably a CUDA tensor of
                the rank the extension expects -- confirm against the kernel.
            keep_arg: Flag forwarded unchanged to the extension (the
                ``LeftPool`` module passes ``self.training``). NOTE(review):
                presumably controls whether ``argmax`` is populated; if so,
                ``backward`` is only meaningful after a ``keep_arg=True`` call.

        Returns:
            A new tensor with the same shape and dtype as ``input``.
        """
        # The extension presumably indexes raw dense memory (backward already
        # passes grad_output.contiguous()); normalise the layout so input,
        # output and argmax all share the same contiguous strides. This is a
        # no-op for inputs that are already contiguous.
        input = input.contiguous()
        output = torch.zeros_like(input)
        argmax = torch.zeros_like(input, dtype=torch.int)
        left_pool_cuda.forward(input, output, argmax, keep_arg)

        # Route every tensor backward needs through save_for_backward instead
        # of a plain ctx attribute, as the autograd docs recommend.
        ctx.save_for_backward(input, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute the input gradient with ``left_pool_cuda.backward``.

        Returns:
            ``(grad_input, None)`` -- there is no gradient for ``keep_arg``.
        """
        # `saved_variables` is a deprecated alias of `saved_tensors`.
        input, argmax = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        # Allocate from the contiguous gradient so grad_input has the same
        # contiguous layout the kernel sees for grad_output.
        grad_input = torch.zeros_like(grad_output)
        left_pool_cuda.backward(grad_output, input, argmax, grad_input)
        return grad_input, None


class RightPoolFunction(Function):
    """Autograd Function for right pooling backed by the ``right_pool_cuda`` extension.

    ``forward`` asks the extension to fill an output tensor and an int32
    ``argmax`` tensor (both allocated with the input's shape); ``backward``
    hands that ``argmax`` back to the extension, which writes the input
    gradient. The Function is ``once_differentiable``: double backward is
    not supported.
    """

    @staticmethod
    def forward(ctx, input, keep_arg=True):
        """Pool ``input`` with ``right_pool_cuda.forward``.

        Args:
            ctx: Autograd context.
            input: Tensor to pool. NOTE(review): presumably a CUDA tensor of
                the rank the extension expects -- confirm against the kernel.
            keep_arg: Flag forwarded unchanged to the extension (the
                ``RightPool`` module passes ``self.training``). NOTE(review):
                presumably controls whether ``argmax`` is populated; if so,
                ``backward`` is only meaningful after a ``keep_arg=True`` call.

        Returns:
            A new tensor with the same shape and dtype as ``input``.
        """
        # The extension presumably indexes raw dense memory (backward already
        # passes grad_output.contiguous()); normalise the layout so input,
        # output and argmax all share the same contiguous strides. This is a
        # no-op for inputs that are already contiguous.
        input = input.contiguous()
        output = torch.zeros_like(input)
        argmax = torch.zeros_like(input, dtype=torch.int)
        right_pool_cuda.forward(input, output, argmax, keep_arg)

        # Route every tensor backward needs through save_for_backward instead
        # of a plain ctx attribute, as the autograd docs recommend.
        ctx.save_for_backward(input, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute the input gradient with ``right_pool_cuda.backward``.

        Returns:
            ``(grad_input, None)`` -- there is no gradient for ``keep_arg``.
        """
        # `saved_variables` is a deprecated alias of `saved_tensors`.
        input, argmax = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        # Allocate from the contiguous gradient so grad_input has the same
        # contiguous layout the kernel sees for grad_output.
        grad_input = torch.zeros_like(grad_output)
        right_pool_cuda.backward(grad_output, input, argmax, grad_input)
        return grad_input, None


# Functional entry points: each name is the `.apply` of the matching
# autograd Function class defined above, called as fn(x, keep_arg).
top_pool, bottom_pool, left_pool, right_pool = (
    TopPoolFunction.apply,
    BottomPoolFunction.apply,
    LeftPoolFunction.apply,
    RightPoolFunction.apply,
)


class TopPool(nn.Module):
    """Module wrapper around ``top_pool``.

    The module's ``training`` flag is passed through as ``keep_arg``.
    """

    def forward(self, x):
        keep_arg = self.training
        pooled = top_pool(x, keep_arg)
        return pooled


class BottomPool(nn.Module):
    """Module wrapper around ``bottom_pool``.

    The module's ``training`` flag is passed through as ``keep_arg``.
    """

    def forward(self, x):
        keep_arg = self.training
        pooled = bottom_pool(x, keep_arg)
        return pooled


class LeftPool(nn.Module):
    """Module wrapper around ``left_pool``.

    The module's ``training`` flag is passed through as ``keep_arg``.
    """

    def forward(self, x):
        keep_arg = self.training
        pooled = left_pool(x, keep_arg)
        return pooled


class RightPool(nn.Module):
    """Module wrapper around ``right_pool``.

    The module's ``training`` flag is passed through as ``keep_arg``.
    """

    def forward(self, x):
        keep_arg = self.training
        pooled = right_pool(x, keep_arg)
        return pooled