# Star import kept first so the explicit imports below take precedence over
# anything `model` re-exports (it presumably provides HyNet).
from model import *

import math
import os

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt

def _stack_side_by_side(img_1, img_2):
    """Return a zero-padded canvas with img_1 on the left and img_2 on the right, top-aligned."""
    height = max(img_1.shape[0], img_2.shape[0])
    width = img_1.shape[1] + img_2.shape[1]
    # shape[2:] is () for grayscale and (C,) for multi-channel images.
    canvas = np.zeros((height, width) + tuple(img_1.shape[2:]), dtype=img_1.dtype)
    canvas[:img_1.shape[0], :img_1.shape[1]] = img_1
    canvas[:img_2.shape[0], img_1.shape[1]:] = img_2
    return canvas


def result_match(img_1_rgb, img_2_rgb, desc_1, desc_2, kpt_1, kpt_2, ratio, scene_name):
    """Match two descriptor sets, keep RANSAC inliers and save a visualisation.

    Pipeline: brute-force L2 2-nearest-neighbour matching, ratio test
    (keep a match when ``best < ratio * second_best``), homography RANSAC
    with a 5 px reprojection threshold, then the inliers are drawn in green
    on a side-by-side canvas that is saved to ``results/<scene_name>.png``.

    Args:
        img_1_rgb, img_2_rgb: images shown on the left / right of the canvas.
        desc_1, desc_2: descriptor arrays accepted by ``cv2.BFMatcher.knnMatch``;
            row i describes keypoint i of ``kpt_1`` / ``kpt_2``.
        kpt_1, kpt_2: keypoint sequences whose elements expose ``pt`` (x, y).
        ratio: ratio-test threshold.
        scene_name: figure title and output file stem.

    Returns:
        None. The figure is written to disk.
    """
    bf = cv2.BFMatcher(cv2.NORM_L2)
    matches_all = bf.knnMatch(desc_1, desc_2, k=2)

    # Ratio test. knnMatch can return fewer than two neighbours for a query
    # (e.g. when desc_2 has a single row); those cannot be ratio-tested.
    good = [pair[0] for pair in matches_all
            if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance]

    # findHomography needs at least 4 correspondences and may still return a
    # None mask if estimation fails; in both cases no match counts as inlier.
    good_ransac = []
    if len(good) >= 4:
        pts_1 = np.float32([kpt_1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        pts_2 = np.float32([kpt_2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        _, mask = cv2.findHomography(pts_1, pts_2, cv2.RANSAC, 5.0)
        if mask is not None:
            good_ransac = [m for m, keep in zip(good, mask.ravel()) if keep]

    new_img = _stack_side_by_side(img_1_rgb, img_2_rgb)

    # Keypoints of the second image are shifted right by the first image's width.
    x_offset = img_1_rgb.shape[1]
    color = (0, 255, 0)
    thickness = 2
    radius = 5
    for m in good_ransac:
        x1, y1 = np.round(kpt_1[m.queryIdx].pt)
        x2, y2 = np.round(kpt_2[m.trainIdx].pt)
        end1 = (int(x1), int(y1))
        end2 = (int(x2) + x_offset, int(y2))
        cv2.line(new_img, end1, end2, color, thickness)
        cv2.circle(new_img, end1, radius, color, thickness)
        cv2.circle(new_img, end2, radius, color, thickness)

    fig = plt.figure()
    plt.imshow(new_img)
    plt.axis('off')
    plt.title(scene_name)
    os.makedirs('results', exist_ok=True)
    img_name = os.path.join('results', scene_name + '.png')
    fig.savefig(img_name, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def extract_patch_opencv(img, kpts, N, mag_factor):
    """
    Rectifies patches around openCV keypoints, and returns patches tensor
    From tFeat
    """
    patches = []
    for kp in kpts:
        x, y = kp.pt
        s = kp.size
        a = kp.angle

        s = mag_factor * s / N
        cos = math.cos(a * math.pi / 180.0)
        sin = math.sin(a * math.pi / 180.0)

        M = np.matrix([
            [+s * cos, -s * sin, (-s * cos + s * sin) * N / 2.0 + x],
            [+s * sin, +s * cos, (-s * sin - s * cos) * N / 2.0 + y]])

        patch = cv2.warpAffine(img, M, (N, N), flags=cv2.WARP_INVERSE_MAP + cv2.INTER_CUBIC + cv2.WARP_FILL_OUTLIERS)

        patches.append(patch)

    patches = torch.from_numpy(np.asarray(patches)).to(torch.uint8).to(torch.float32)
    patches = torch.unsqueeze(patches, 1)

    return patches

def sort_response(kpt):
    """Return keypoint indices ordered from strongest to weakest ``response``.

    Args:
        kpt: sequence of keypoints exposing a ``response`` attribute.

    Returns:
        NumPy integer array of indices into ``kpt``, highest response first.
    """
    responses = np.array([point.response for point in kpt])
    ascending_order = np.argsort(responses)
    # argsort is ascending; reverse it so the strongest keypoint comes first.
    return ascending_order[::-1]

def test_net(net, patch, dim_desc, sz_batch=200, device=None):
    """Compute descriptors for a stack of patches in mini-batches, without gradients.

    Args:
        net: torch module mapping a batch of patches to (batch, dim_desc) outputs.
        patch: tensor whose first dimension indexes patches.
        dim_desc: length of each descriptor produced by ``net``.
        sz_batch: number of patches per forward pass.
        device: device each batch is moved to before ``net`` runs. Defaults to
            the device of ``net``'s first parameter, or CPU if it has none.

    Returns:
        CPU float32 tensor of shape (patch.size(0), dim_desc).
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    nb_patch = patch.size(0)
    desc = torch.zeros(nb_patch, dim_desc)
    with torch.no_grad():
        for st in range(0, nb_patch, sz_batch):
            # Slicing clamps at nb_patch, so the last (short) batch is handled.
            batch = patch[st:st + sz_batch].to(device)
            desc[st:st + sz_batch] = net(batch).to('cpu')
    return desc


# Not a unit test despite the ``test_`` prefix; keep pytest from collecting it.
test_net.__test__ = False





if __name__ == '__main__':
    # Example matching on the HPatches dataset.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)

    # Define the detector and the patch-sampling parameters.
    detector = cv2.BRISK_create()
    mag_factor = 3
    num_kpt_max = 1024
    patch_size = 32

    # Define the descriptor. map_location lets weights saved on a GPU load on a
    # CPU-only machine.
    net = HyNet()
    net.load_state_dict(torch.load('weights/HyNet_LIB.pth', map_location=device))
    net.to(device)
    net.eval()

    # Get the images.
    # For illustration, we provide v_dogman, v_bird, i_londonbridge and i_crownday
    # from the HPatches dataset, where v_ stands for viewpoint change and i_ for
    # illumination change. Any other image pair can also be tested.
    scene_name = 'i_crownday'
    img_id_1 = 1
    img_id_2 = 3

    imgs_rgb = []
    for img_id in (img_id_1, img_id_2):
        img_path = os.path.join('images', scene_name, str(img_id) + '.ppm')
        img_bgr = cv2.imread(img_path)
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with a clear message.
        if img_bgr is None:
            raise FileNotFoundError('Could not read image: ' + img_path)
        imgs_rgb.append(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    img_1_rgb, img_2_rgb = imgs_rgb

    # Detect keypoints on the grayscale images, keep at most num_kpt_max of the
    # highest-response ones, and only then cut patches for the kept keypoints.
    kpts = []
    patches = []
    for img_rgb in (img_1_rgb, img_2_rgb):
        img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
        kpt = detector.detect(img_gray, None)
        if len(kpt) > num_kpt_max:
            kpt = [kpt[i] for i in sort_response(kpt)[:num_kpt_max]]
        kpts.append(kpt)
        patches.append(extract_patch_opencv(img_gray, kpt, patch_size, mag_factor))
    kpt_1, kpt_2 = kpts
    patches_1, patches_2 = patches

    desc_1 = test_net(net, patches_1, 128, sz_batch=200).cpu().numpy()
    desc_2 = test_net(net, patches_2, 128, sz_batch=200).cpu().numpy()

    # Ratio-test threshold between the first and second nearest neighbour.
    # 1 keeps every match whose best neighbour is strictly closer than the
    # second; lower values (e.g. 0.8) make the test stricter.
    nn_ratio = 1
    result_match(img_1_rgb, img_2_rgb, desc_1, desc_2, kpt_1, kpt_2, nn_ratio, scene_name=scene_name)
