import cv2
import numpy as np
import torch 


from torch.nn import functional as F
from scipy.spatial.transform import Rotation as sRot
from relive.utils.torch_geometry_transforms import *
from relive.utils.transformation import quaternion_slerp
from relive.utils.torch_utils import quaternion_matrix_batch


def batch_rodrigues(axisang):
    # This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L37
    # axisang N x 3
    axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(axisang_norm, -1)
    axisang_normalized = torch.div(axisang, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
    rot_mat = quat2mat(quat)
    rot_mat = rot_mat.view(rot_mat.shape[0], 9)
    return rot_mat



def smpl_mat_to_aa(poses):
    poses_aa = []
    for pose_frame in poses:
        pose_frames = []
        for joint in pose_frame:
            pose_frames.append(cv2.Rodrigues(joint)[0].flatten())
        pose_frames = np.array(pose_frames)
        poses_aa.append(pose_frames)
    poses_aa = np.array(poses_aa)
    return poses_aa


def compute_rotation_matrix_from_ortho6d(ortho6d):
    x_raw = ortho6d[:,0:3]#batch*3
    y_raw = ortho6d[:,3:6]#batch*3
        
    x = normalize_vector(x_raw) #batch*3
    z = cross_product(x,y_raw) #batch*3
    z = normalize_vector(z)#batch*3
    y = cross_product(z,x)#batch*3
        
    x = x.view(-1,3,1)
    y = y.view(-1,3,1)
    z = z.view(-1,3,1)
    zeros = torch.zeros(z.shape, dtype = z.dtype).to(ortho6d.device)
    matrix = torch.cat((x,y,z, zeros), 2) #batch*3*3
    return matrix

def convert_orth_6d_to_mat(orth6d):
    num_joints = int(orth6d.shape[-1]/6)
    orth6d_flat = orth6d.reshape(-1, 6)

    rot_mat6d = compute_rotation_matrix_from_ortho6d(orth6d_flat)[:,:,:3]

    shape_curr = list(orth6d.shape)
    shape_curr[-1] = num_joints
    shape_curr += [3, 3]
    shape_curr = tuple([int(i) for i in shape_curr])
    rot_mat = rot_mat6d.reshape(shape_curr) 
    return rot_mat


def cross_product( u, v):
    batch = u.shape[0]
    #print (u.shape)
    #print (v.shape)
    i = u[:,1]*v[:,2] - u[:,2]*v[:,1]
    j = u[:,2]*v[:,0] - u[:,0]*v[:,2]
    k = u[:,0]*v[:,1] - u[:,1]*v[:,0]
        
    out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3
        
    return out

def compute_orth6d_from_rotation_matrix(rot_mats):
    rot_mats = rot_mats[:,:,:2].transpose(1, 2).reshape(-1, 6)
    return rot_mats

def convert_mat_to_6d(poses):
    if torch.is_tensor(poses):
        curr_pose = poses.to(poses.device).float().reshape(-1, 3, 3)
    else:
        curr_pose = torch.tensor(poses).to(poses.device).float().reshape(-1, 3, 3)


    orth6d = curr_pose[:,:,:2].transpose(1, 2).reshape(-1, 6)
    orth6d = orth6d.view(poses.shape[0], -1, 6) 
    return orth6d

def convert_aa_to_6d(poses):
    if torch.is_tensor(poses):
        curr_pose = poses.to(poses.device).float().reshape(-1, 3)
    else:
        curr_pose = torch.tensor(poses).to(poses.device).float().reshape(-1, 3)
    rot_mats = angle_axis_to_rotation_matrix(curr_pose)
    rot_mats = rot_mats[:,:3,:]
    orth6d = compute_orth6d_from_rotation_matrix(rot_mats)
    orth6d = orth6d.view(poses.shape[0], -1, 6) 
    orth6d = orth6d
    return orth6d

def convert_6d_to_aa(orth6d):
    orth6d_flat = orth6d.reshape(-1, 6)
    rot_mat6d = compute_rotation_matrix_from_ortho6d(orth6d_flat)
    pose_aa = rotation_matrix_to_angle_axis(rot_mat6d)
    
    shape_curr = list(orth6d.shape)
    shape_curr[-1] /= 2
    shape_curr = tuple([int(i) for i in shape_curr])
    pose_aa = pose_aa.reshape(shape_curr)
    return pose_aa

def convert_6d_to_mat(orth6d):
    num_joints = int(orth6d.shape[-1]/6)
    orth6d_flat = orth6d.reshape(-1, 6)

    rot_mat6d = compute_rotation_matrix_from_ortho6d(orth6d_flat)[:,:,:3]

    shape_curr = list(orth6d.shape)
    shape_curr[-1] = num_joints
    shape_curr += [3, 3]
    shape_curr = tuple([int(i) for i in shape_curr])
    rot_mat = rot_mat6d.reshape(shape_curr) 
    return rot_mat


# def convert_quat_to_6d(quat):
#     all_mats = sRot.from_quat(quat).as_matrix()
#     quat_6d = convert_mat_to_6d(torch.tensor(all_mats))
#     return quat_6d

def convert_quat_to_6d(quats):
    quat_flat = quats.reshape(-1, 4)
    all_mats = quaternion_matrix_batch(quat_flat)
    quat_6d = convert_mat_to_6d(torch.tensor(all_mats))


    shape_curr = list(quats.shape)
    shape_curr[-1] = 6
    shape_curr = tuple([int(i) for i in shape_curr])
    quat_6d = quat_6d.reshape(shape_curr)
    return quat_6d

def convert_6d_to_quat(orth6d):
    num_joints = int(orth6d.shape[-1]/6)
    orth6d_flat = orth6d.reshape(-1, 6)

    rot_mat6d = compute_rotation_matrix_from_ortho6d(orth6d_flat)[:,:,:3]
    quats = rotation_matrix_to_quaternion(rot_mat6d)

    shape_curr = list(orth6d.shape)
    shape_curr[-1] = 4
    shape_curr = tuple([int(i) for i in shape_curr])
    quats = quats.reshape(shape_curr) 

    return quats




def normalize_vector( v, return_mag =False):
    batch=v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))# batch
    v_mag = torch.max(v_mag, torch.autograd.Variable(torch.tensor([1e-8], dtype = v_mag.dtype).to(v.device)))
    v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
    v = v/v_mag
    if(return_mag==True):
        return v, v_mag[:,0]
    else:
        return v


def vertizalize_smpl_root(poses, root_vec = [np.pi/2, 0,0]):
    device = poses.device
    target_mat = angle_axis_to_rotation_matrix(torch.tensor([root_vec], dtype = poses.dtype).to(device))[:,:3,:3].to(device)
    org_mats =  angle_axis_to_rotation_matrix(poses[:, :3])[:,:3,:3].to(device)
    org_mat_inv = torch.inverse(org_mats[0]).to(device)
    apply_mat = torch.matmul(target_mat, org_mat_inv)
    res_root_mat = torch.matmul(apply_mat, org_mats)
    zeros = torch.zeros((res_root_mat.shape[0], res_root_mat.shape[1], 1), dtype=res_root_mat.dtype).to(device)
    res_root_mats_4 = torch.cat((res_root_mat, zeros), 2) #batch*3*4
    res_root_aa = rotation_matrix_to_angle_axis(res_root_mats_4)

    poses[:, :3] = res_root_aa
    # print(res_root_aa)
    return poses

def rot6d_to_rotmat(x):
    x = x.view(-1,3,2)

    # Normalize the first vector
    b1 = F.normalize(x[:, :, 0], dim=1, eps=1e-6)

    dot_prod = torch.sum(b1 * x[:, :, 1], dim=1, keepdim=True)
    # Compute the second vector by finding the orthogonal complement to it
    b2 = F.normalize(x[:, :, 1] - dot_prod * b1, dim=-1, eps=1e-6)

    # Finish building the basis by taking the cross product
    b3 = torch.cross(b1, b2, dim=1)
    rot_mats = torch.stack([b1, b2, b3], dim=-1)

    return rot_mats


def perspective_projection_cam(pred_joints, pred_camera):
    pred_cam_t = torch.stack([pred_camera[:, 1],
                              pred_camera[:, 2],
                              2 * 5000. / (224. * pred_camera[:, 0] + 1e-9)], dim=-1)
    batch_size = pred_joints.shape[0]
    camera_center = torch.zeros(batch_size, 2)
    pred_keypoints_2d = perspective_projection(pred_joints,
                                               rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(pred_joints.device),
                                               translation=pred_cam_t,
                                               focal_length=5000.,
                                               camera_center=camera_center)
    # Normalize keypoints to [-1,1]
    pred_keypoints_2d = pred_keypoints_2d / (224. / 2.)
    return pred_keypoints_2d

def perspective_projection(points, rotation, translation,
                           focal_length, camera_center):
    """
    This function computes the perspective projection of a set of points.
    Input:
        points (bs, N, 3): 3D points
        rotation (bs, 3, 3): Camera rotation
        translation (bs, 3): Camera translation
        focal_length (bs,) or scalar: Focal length
        camera_center (bs, 2): Camera center
    """
    batch_size = points.shape[0]
    K = torch.zeros([batch_size, 3, 3], device=points.device)
    K[:,0,0] = focal_length
    K[:,1,1] = focal_length
    K[:,2,2] = 1.
    K[:,:-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:,:,-1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    return projected_points[:, :, :-1]

def quat_smooth(quat, ratio = 0.3):
    """ Converts quaternion to minimize Euclidean distance from previous quaternion (wxyz order) """
    for q in range(1, quat.shape[0]):
        quat[q] = quaternion_slerp(quat[q-1], quat[q], ratio)
    return quat
