#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

// Grid-stride loop over [0, n): each thread starts at its global 1D index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// launched threads (blockDim.x * gridDim.x). Any 1D grid size therefore covers
// the whole range. `n` is parenthesized so that compound expressions such as
// `a & b` or `x ? y : z` are compared as a whole.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Threads per block, used by GET_BLOCKS and by the kernel launch below.
#define THREADS_PER_BLOCK 1024

// Number of 1D blocks to launch for N work items at THREADS_PER_BLOCK threads
// per block. The result is capped at 65000, so kernels must use a grid-stride
// loop (CUDA_1D_KERNEL_LOOP) to cover every item.
// Always returns at least 1. CUDA rejects a launch with zero blocks as an
// invalid configuration. A one-block launch over an empty range with a
// bounds-checked loop simply does no work.
inline int GET_BLOCKS(const int N) {
  const int max_block_num = 65000;
  if (N <= 0) {
    return 1;
  }
  // Ceil-division written without `N + THREADS_PER_BLOCK - 1`, which can
  // overflow int when N is close to INT_MAX.
  const int optimal_block_num =
      N / THREADS_PER_BLOCK + (N % THREADS_PER_BLOCK != 0 ? 1 : 0);
  return min(optimal_block_num, max_block_num);
}

// Backward kernel for top pooling: routes each element of top_diff to the
// input row recorded for it in argmax_data.
//
// Work distribution: one logical item per element of top_diff (nthreads of
// them), walked with CUDA_1D_KERNEL_LOOP, so any 1D grid/block shape is valid.
// No shared memory is used.
//
// Layout: the flat index is decoded as (b, c, h, w) with w fastest, i.e. a
// contiguous [batch, channels, height, width] layout for top_diff, argmax_data
// and bottom_diff. NOTE(review): callers must pass contiguous tensors; confirm.
//
// Params:
//   nthreads     total number of top_diff elements (batch*channels*height*width)
//   top_diff     incoming gradient w.r.t. the pooled output
//   argmax_data  per-output source row; assumed to lie in [0, height), which
//                is NOT checked here
//   channels, height, width  tensor dimensions used to decode the flat index
//   bottom_diff  gradient w.r.t. the input; accumulated into, never cleared
template <typename scalar_t>
__global__ void TopPoolBackward(const int nthreads,
                                const scalar_t *top_diff, const int *argmax_data,
                                const int channels, const int height, const int width,
                                scalar_t *bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int w = index % width;
    const int c = (index / width / height) % channels;
    const int b = index / width / height / channels;

    // The output row is not needed: the gradient goes to the row stored in
    // argmax (presumably recorded by the forward pass).
    const int bottom_h = argmax_data[index];

    const int bottom_offset =
        (b * channels + c) * height * width + bottom_h * width + w;
    // Atomic because distinct output elements can name the same input
    // location (same b, c, w and argmax row).
    atomicAdd(bottom_diff + bottom_offset, top_diff[index]);
  }
}

// Launches TopPoolBackward on the current ATen CUDA stream. It scatters
// top_grad into bottom_grad at the rows recorded in argmax.
//
// Params:
//   top_grad     gradient of the pooled output; batch_size*channels*height*width
//                elements read through a raw pointer, so assumed contiguous
//                (NOTE(review): confirm at call sites)
//   argmax       same element count, read as int32 source-row indices
//   batch_size, channels, height, width  dimensions of all three tensors
//   bottom_grad  same element count; accumulated into via atomicAdd and NOT
//                zeroed here. Presumably the caller passes a zero-initialized
//                tensor (TODO confirm).
// Supported dtypes: float and half. Double is rejected with an ATen error
// (catchable by the caller) instead of terminating the process.
// Returns 1. The launch is asynchronous. Only launch errors are checked here
// (cudaGetLastError); faults during kernel execution surface at a later sync.
int TopPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor argmax,
                           const int batch_size, const int channels,
                           const int height, const int width,
                           at::Tensor bottom_grad) {
  const int output_size = batch_size * channels * height * width;

  if (top_grad.scalar_type() == at::kDouble) {
    AT_ERROR("double is not supported");
  }

  // Nothing to scatter. Returning early also avoids a zero-block launch, which
  // CUDA rejects as an invalid configuration.
  if (output_size == 0) {
    return 1;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "TopPoolBackwardLaucher", ([&] {
        const scalar_t *top_diff = top_grad.data<scalar_t>();
        const int *argmax_data = argmax.data<int>();
        scalar_t *bottom_diff = bottom_grad.data<scalar_t>();

        TopPoolBackward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0,
               at::cuda::getCurrentCUDAStream()>>>(
                output_size, top_diff, argmax_data,
                channels, height, width,
                bottom_diff);
      }));
  THCudaCheck(cudaGetLastError());
  return 1;
}
