#include <torch/extension.h>

#include <ATen/ATen.h>

#include <vector>

// Backward launcher, defined outside this visible source (presumably the
// extension's CUDA file — confirm). The "Laucher" spelling must match that
// definition, so it is kept as-is.
int RightPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor argmax,
                             const int batch_size, const int channels,
                             const int height, const int width,
                             at::Tensor bottom_grad);

// AT_CHECK was deprecated in favour of TORCH_CHECK and later removed from
// PyTorch. Alias it only when it is missing, so the checks below build on both
// old and new releases; where AT_CHECK still exists nothing changes.
#ifndef AT_CHECK
#define AT_CHECK TORCH_CHECK
#endif

// Argument validation helpers. Each one fails through AT_CHECK with a message
// that names the offending expression (#x).
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// The do { ... } while (0) wrapper makes both checks expand to ONE statement,
// so CHECK_INPUT(t); is safe after an unbraced `if`/`else`. The previous
// two-statement expansion would only have guarded the CUDA check.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

/// Right-pool forward pass: a running maximum along dim 3, left to right.
///
/// Writes into `output`, column by column along dim 3:
///   output[..., 0] = input[..., 0]
///   output[..., i] = max(input[..., i], output[..., i - 1])   for i >= 1
/// When `cal_argmax` is true, it also writes into `argmax` the dim-3 index at
/// which each running maximum was reached. A strict `>` comparison is used, so
/// ties keep the earlier index.
///
/// @param input      tensor with at least 4 dims; only dim 3 is iterated here.
/// @param output     pre-allocated result tensor. Assumed to have the same
///                   shape as `input` (TODO confirm with callers).
/// @param argmax     pre-allocated index tensor; touched only if `cal_argmax`.
/// @param cal_argmax whether to fill `argmax`.
/// @return always 1.
int right_pool_forward(at::Tensor input, at::Tensor output, at::Tensor argmax, bool cal_argmax) {
  const int64_t width = input.size(3);

  // Seed the running maximum with the first (left-most) column.
  at::Tensor input_temp  = input.select(3, 0);
  at::Tensor output_temp = output.select(3, 0);
  output_temp.copy_(input_temp);

  // Scratch comparison mask, reused for every column. Building it with at::gt
  // gives it the same shape, device and dtype that gt_out / masked_fill_
  // expect on the installed PyTorch (uint8 on old releases, bool on newer
  // ones). Before, it was hard-coded to a CUDA uint8 tensor and allocated even
  // when unused. Its contents are overwritten by gt_out before each use.
  at::Tensor gt_mask;
  if (cal_argmax) {
    argmax.select(3, 0).fill_(0);
    gt_mask = at::gt(input_temp, output_temp);
  }

  for (int64_t ind = 0; ind < width - 1; ++ind) {
    input_temp  = input.select(3, ind + 1);
    output_temp = output.select(3, ind);
    at::Tensor max_temp = output.select(3, ind + 1);

    if (cal_argmax) {
      // Carry the previous index forward, and replace it with ind + 1 wherever
      // the new column is strictly larger than the running maximum.
      at::gt_out(gt_mask, input_temp, output_temp);
      at::Tensor argmax_temp = argmax.select(3, ind + 1);
      argmax_temp.copy_(argmax.select(3, ind));
      argmax_temp.masked_fill_(gt_mask, ind + 1);
    }

    at::max_out(max_temp, input_temp, output_temp);
  }

  return 1;
}

/// Right-pool backward pass.
///
/// Validates that all four tensors are contiguous CUDA tensors, then forwards
/// `grad_output`, `argmax` and `input`'s first four sizes to
/// RightPoolBackwardLaucher. `grad_input` is passed as its `bottom_grad`
/// argument, presumably as the output buffer (confirm against the launcher).
///
/// NOTE(review): only `input`'s shape reaches the launcher. The shapes of
/// grad_output / argmax / grad_input are not checked against it here, and the
/// launcher's return value is ignored — confirm callers guarantee matching
/// shapes.
///
/// @return always 1.
int right_pool_backward(at::Tensor grad_output,
                        at::Tensor input,
                        at::Tensor argmax,
                        at::Tensor grad_input) {
  // Same check order as before, so the first failing argument is reported.
  CHECK_INPUT(grad_output);
  CHECK_INPUT(input);
  CHECK_INPUT(argmax);
  CHECK_INPUT(grad_input);

  // The launcher takes plain ints, so narrow the int64_t sizes explicitly.
  const int n = static_cast<int>(input.size(0));
  const int c = static_cast<int>(input.size(1));
  const int h = static_cast<int>(input.size(2));
  const int w = static_cast<int>(input.size(3));

  RightPoolBackwardLaucher(grad_output, argmax, n, c, h, w, grad_input);
  return 1;
}

// Python bindings. Both entry points are registered with a pybind11
// call_guard that releases the GIL while the C++ function runs.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;
  mod.def("forward", &right_pool_forward, "Right Pool Forward", ReleaseGil());
  mod.def("backward", &right_pool_backward, "Right Pool Backward", ReleaseGil());
}
