import os.path as osp
import sys
import time

import torch
from torch.autograd import gradcheck

sys.path.append(osp.abspath(osp.join(__file__, '../../')))
from corner_pool import TopPool  # noqa: E402, isort:skip
from corner_pool import TopPoolTest, BottomPoolTest, LeftPoolTest, RightPoolTest

# Random input feature map. It is allocated directly on the GPU so that
# `feat` is a leaf tensor: calling `.cuda()` on a CPU tensor created with
# requires_grad=True returns a non-leaf copy whose `.grad` is never filled.
# feat = torch.randn(2, 16, 100, 150, device='cuda', requires_grad=True)
feat = torch.randn(4, 16, 15, 15, device='cuda', requires_grad=True)

# Forward consistency check: TopPool must produce exactly the same tensor
# (same size, same elements) as TopPoolTest on the same input.
tp = TopPool()
tpt = TopPoolTest()

assert torch.equal(tp(feat), tpt(feat)), \
    'TopPool and TopPoolTest forward outputs differ'

# Optional timing benchmark, disabled by default. CUDA kernels are launched
# asynchronously, so the device is synchronized before every clock read;
# without that only the launch overhead would be measured.
BENCHMARK = False

if BENCHMARK:
    for bench_name, bench_module in (('TopPool', tp), ('TopPoolTest', tpt)):
        # Reset gradients so each module's backward starts from a clean state
        # instead of accumulating into the previous run's `.grad`.
        feat.grad = None

        torch.cuda.synchronize()
        st = time.perf_counter()
        bench_out = bench_module(feat).sum()
        torch.cuda.synchronize()
        print('{} forward pass: {}'.format(
            bench_name, time.perf_counter() - st))

        st = time.perf_counter()
        bench_out.backward()
        torch.cuda.synchronize()
        print('{} backward pass: {}'.format(
            bench_name, time.perf_counter() - st))


# gradcheck compares analytical gradients with central finite differences.
# In float32 the rounding error of the finite difference (~ulp / eps) is
# larger than atol=1e-3, so the check is run on a float64 leaf copy of `feat`.
# NOTE(review): assumes the *PoolTest modules accept float64 input — confirm.
inputs = feat.detach().double().requires_grad_()
print('Gradcheck for corner_pooling...')
# test = gradcheck(TopPoolTest(), inputs, eps=1e-5, atol=1e-3)
# print(test)

# gradcheck raises on mismatch by default; on success it returns True.
for pool in (BottomPoolTest(), LeftPoolTest(), RightPoolTest()):
    test = gradcheck(pool, inputs, eps=1e-5, atol=1e-3)
    print(test)
