#include <torch/extension.h>

#include <ATen/ATen.h>

#include <vector>

// Backward launcher, declared here and defined outside this file (presumably
// in the companion CUDA source — TODO confirm). Judging by the parameter
// names it routes `top_grad` into `bottom_grad` using the indices stored in
// `argmax`, with the geometry passed explicitly as plain ints; verify against
// the definition. The spelling "Laucher" must match that definition, so it is
// intentionally left unchanged.
int TopPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor argmax,
                           const int batch_size, const int channels,
                           const int height, const int width,
                           at::Tensor bottom_grad);

// Argument-validation helpers for tensors handed to the CUDA launcher.
// TORCH_CHECK supersedes the deprecated AT_CHECK (which recent PyTorch no
// longer provides); fall back to AT_CHECK on older releases without it.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x, " must be contiguous ")
// Wrapped in do { } while (0) so `CHECK_INPUT(x);` is a single statement;
// otherwise an unbraced `if (...) CHECK_INPUT(x);` would guard only the
// CUDA check and always run the contiguity check.
#define CHECK_INPUT(x)     \
  do {                     \
    CHECK_CUDA(x);         \
    CHECK_CONTIGUOUS(x);   \
  } while (0)

// Top pooling forward pass, computed with ATen ops on whatever device
// `input` lives on. Pooling runs along dim 2 (height): every output row holds
// the element-wise maximum of the input rows from itself down to the last
// row, i.e.
//   output[..., h, :] = max(input[..., h, :], ..., input[..., height - 1, :]).
// It is accumulated as a running maximum from the bottom row upward.
//
// @param input      Tensor indexed as (batch, channel, height, width).
// @param output     Pre-allocated result, written in place. Assumed to have
//                   the same shape as `input` (not checked here).
// @param argmax     Pre-allocated index tensor, written in place only when
//                   `cal_argmax` is true. Assumed to have the same shape as
//                   `input`. Each element receives the dim-2 row index that
//                   supplied the maximum. Because the comparison is strict,
//                   ties keep the larger row index (the row nearer the bottom).
// @param cal_argmax Whether to fill `argmax`.
// @return Always 1.
int top_pool_forward(at::Tensor input, at::Tensor output, at::Tensor argmax, bool cal_argmax) {
  const int64_t height = input.size(2);
  // An empty height dimension leaves nothing to pool; without this guard the
  // select(2, height - 1) below would be out of range.
  if (height == 0) {
    return 1;
  }

  // Seed the running maximum with the bottom row.
  const int64_t last = height - 1;
  output.select(2, last).copy_(input.select(2, last));
  if (cal_argmax) {
    argmax.select(2, last).fill_(last);
  }

  // Comparison buffer reused across rows. It is created lazily from the first
  // comparison so that it lives on `input`'s device (not a hard-coded CUDA
  // device). It also gets the dtype this ATen version's comparison ops
  // produce (uint8 on old releases, bool on newer ones), which is the mask
  // dtype masked_fill_ accepts in that same version.
  at::Tensor gt_mask;
  for (int64_t row = last - 1; row >= 0; --row) {
    const at::Tensor input_row = input.select(2, row);
    // Already holds the maximum over rows row+1 .. last.
    const at::Tensor below_max = output.select(2, row + 1);
    at::Tensor row_max = output.select(2, row);

    if (cal_argmax) {
      if (gt_mask.defined()) {
        at::gt_out(gt_mask, input_row, below_max);
      } else {
        gt_mask = at::gt(input_row, below_max);
      }
      // Inherit the index from the row below, then overwrite it wherever the
      // current row is strictly larger.
      at::Tensor argmax_row = argmax.select(2, row);
      argmax_row.copy_(argmax.select(2, row + 1));
      argmax_row.masked_fill_(gt_mask, row);
    }

    at::max_out(row_max, input_row, below_max);
  }

  return 1;
}

// Top pooling backward pass.
// Validates that all four tensors are contiguous CUDA tensors (checked in
// argument order). It then forwards them, together with the
// (batch, channel, height, width) sizes read from `input`, to
// TopPoolBackwardLaucher. `grad_input` is written by the launcher.
// Returns 1 unconditionally; the launcher's return value is not inspected.
int top_pool_backward(at::Tensor grad_output,
                      at::Tensor input,
                      at::Tensor argmax,
                      at::Tensor grad_input) {
  // Order matters only for which error is reported first.
  CHECK_INPUT(grad_output);
  CHECK_INPUT(input);
  CHECK_INPUT(argmax);
  CHECK_INPUT(grad_input);

  // The launcher takes the geometry as plain ints.
  const int num_batches  = input.size(0);
  const int num_channels = input.size(1);
  const int num_rows     = input.size(2);
  const int num_cols     = input.size(3);

  TopPoolBackwardLaucher(grad_output,
                         argmax,
                         num_batches,
                         num_channels,
                         num_rows,
                         num_cols,
                         grad_input);
  return 1;
}

// Python bindings: exposes the forward/backward entry points as
// `forward` and `backward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Both entry points run with the GIL released (pybind11 call guard).
  using release_gil = py::call_guard<py::gil_scoped_release>;
  mod.def("forward", &top_pool_forward, "Top Pool Forward", release_gil());
  mod.def("backward", &top_pool_backward, "Top Pool Backward", release_gil());
}
