import torch

from ..bbox import assign_and_sample, build_assigner, PseudoSampler
from ..utils import multi_apply


def point_hm_target(proposals_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    cfg,
                    unmap_outputs=True):
    """Compute corner heatmap/offset targets for the points of a batch.

    Per-image targets are produced by ``point_hm_target_single`` (via
    ``multi_apply``) and then regrouped from per-image to per-level.

    Args:
        proposals_list (list[list[Tensor]]): Multi level points of each
            image. Modified in place: each image's entry is replaced by
            the concatenation of its levels.
        valid_flag_list (list[list[Tensor]]): Multi level valid flags of
            each image. Modified in place the same way.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        gt_labels_list (list[Tensor]): Ground truth labels of each image.
        cfg (dict): Train sample configs; ``cfg.assigner`` is used to build
            the heatmap assigner.
        unmap_outputs (bool): Whether per-image targets are scattered back
            to all points (invalid points filled with 0) before being
            split by level.

    Returns:
        tuple: ``(gt_hm_tl_list, gt_offset_tl_list, hm_tl_weight_list,
        offset_tl_weight_list, gt_hm_br_list, gt_offset_br_list,
        hm_br_weight_list, offset_br_weight_list, num_total_pos_tl,
        num_total_neg_tl, num_total_pos_br, num_total_neg_br)``.
        Each ``*_list`` holds one tensor per feature level. Each
        ``num_total_*`` is an int summing the per-image positive/negative
        index counts, where every image contributes at least 1.
    """
    assert len(proposals_list) == len(valid_flag_list)

    # Points number of each level, read from the first image only; the
    # level split below assumes every image shares this layout.
    num_level_proposals = [points.size(0) for points in proposals_list[0]]

    # Concat all level points and flags of each image into single tensors.
    for i, (proposals, flags) in enumerate(zip(proposals_list, valid_flag_list)):
        assert len(proposals) == len(flags)
        proposals_list[i] = torch.cat(proposals)
        valid_flag_list[i] = torch.cat(flags)

    (all_gt_hm_tl, all_gt_offset_tl, all_hm_tl_weights, all_offset_tl_weights,
     pos_inds_tl_list, neg_inds_tl_list,
     all_gt_hm_br, all_gt_offset_br, all_hm_br_weights, all_offset_br_weights,
     pos_inds_br_list, neg_inds_br_list) = multi_apply(
         point_hm_target_single,
         proposals_list,
         valid_flag_list,
         gt_bboxes_list,
         gt_labels_list,
         cfg=cfg,
         unmap_outputs=unmap_outputs)

    def _num_total(inds_list):
        # Each image contributes at least 1, even with no sampled indices.
        return sum(max(inds.numel(), 1) for inds in inds_list)

    num_total_pos_tl = _num_total(pos_inds_tl_list)
    num_total_neg_tl = _num_total(neg_inds_tl_list)
    num_total_pos_br = _num_total(pos_inds_br_list)
    num_total_neg_br = _num_total(neg_inds_br_list)

    # NOTE(review): with unmap_outputs=False the per-image targets only cover
    # the valid points, while num_level_proposals counts all points, so the
    # level split below may not line up. Confirm before relying on it.
    def _to_levels(per_image_targets):
        return images_to_levels(per_image_targets, num_level_proposals)

    return (_to_levels(all_gt_hm_tl), _to_levels(all_gt_offset_tl),
            _to_levels(all_hm_tl_weights), _to_levels(all_offset_tl_weights),
            _to_levels(all_gt_hm_br), _to_levels(all_gt_offset_br),
            _to_levels(all_hm_br_weights), _to_levels(all_offset_br_weights),
            num_total_pos_tl, num_total_neg_tl, num_total_pos_br, num_total_neg_br)


def images_to_levels(target, num_level_grids):
    """Regroup per-image targets into per-level targets.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]

    Args:
        target (list[Tensor]): One tensor per image; all are stacked along a
            new leading (image) dimension.
        num_level_grids (list[int]): Number of entries belonging to each
            level, in order, along dimension 1 of the stacked tensor.

    Returns:
        list[Tensor]: One slice per level. A leading image dimension of size
        1 is squeezed away.
    """
    stacked = torch.stack(target, 0)
    # Cumulative boundaries: [0, n0, n0 + n1, ...].
    bounds = [0]
    for size in num_level_grids:
        bounds.append(bounds[-1] + size)
    return [stacked[:, lo:hi].squeeze(0)
            for lo, hi in zip(bounds[:-1], bounds[1:])]


def point_hm_target_single(flat_points,
                           inside_flags,
                           gt_bboxes,
                           gt_labels,
                           cfg,
                           unmap_outputs=True):
    """Compute top-left / bottom-right corner targets for a single image.

    Args:
        flat_points (Tensor): All points of the image (levels concatenated);
            rows are selected with ``inside_flags``.
        inside_flags (Tensor): Per-point mask selecting the valid points.
        gt_bboxes (Tensor): Ground truth bboxes of the image.
        gt_labels (Tensor): Ground truth labels of the image.
        cfg (dict): Train config; ``cfg.assigner`` builds the heatmap assigner.
        unmap_outputs (bool): If True, scatter the targets and weights back
            to all rows of ``flat_points`` (invalid rows filled with 0).

    Returns:
        tuple: ``(gt_hm_tl, gt_offset_tl, hm_tl_weights, offset_tl_weights,
        pos_inds_tl, neg_inds_tl, gt_hm_br, gt_offset_br, hm_br_weights,
        offset_br_weights, pos_inds_br, neg_inds_br)``. The pos/neg indices
        are those returned by the assigner for the valid points and are
        never unmapped.
    """
    # Assign gt to the valid points only.
    valid_points = flat_points[inside_flags, :]
    num_valid = valid_points.shape[0]

    assigner = build_assigner(cfg.assigner)
    (gt_hm_tl, gt_offset_tl, pos_inds_tl, neg_inds_tl,
     gt_hm_br, gt_offset_br, pos_inds_br, neg_inds_br) = assigner.assign(
         valid_points, gt_bboxes, gt_labels)

    def corner_weights(pos_inds, neg_inds):
        # Heatmap weights are 1 at positive and negative indices; offset
        # weights (2 columns) are 1 at positive indices only.
        hm_w = valid_points.new_zeros(num_valid, dtype=torch.float)
        hm_w[pos_inds] = 1.0
        hm_w[neg_inds] = 1.0
        off_w = valid_points.new_zeros([num_valid, 2], dtype=torch.float)
        off_w[pos_inds, :] = 1.0
        return hm_w, off_w

    hm_tl_weights, offset_tl_weights = corner_weights(pos_inds_tl, neg_inds_tl)
    hm_br_weights, offset_br_weights = corner_weights(pos_inds_br, neg_inds_br)

    tl_outputs = [gt_hm_tl, gt_offset_tl, hm_tl_weights, offset_tl_weights]
    br_outputs = [gt_hm_br, gt_offset_br, hm_br_weights, offset_br_weights]

    # Map up to the original set of points.
    if unmap_outputs:
        total = flat_points.shape[0]
        tl_outputs = [unmap(t, total, inside_flags) for t in tl_outputs]
        br_outputs = [unmap(t, total, inside_flags) for t in br_outputs]

    return (*tl_outputs, pos_inds_tl, neg_inds_tl,
            *br_outputs, pos_inds_br, neg_inds_br)


def unmap(data, count, inds, fill=0):
    """Scatter a subset of items back into a full-size tensor.

    Args:
        data (Tensor): Values for the selected rows.
        count (int): Number of rows in the output.
        inds (Tensor): Index or mask selecting which output rows receive
            ``data``.
        fill (int | float): Value for the rows not selected by ``inds``.

    Returns:
        Tensor: Tensor of shape ``(count, *data.shape[1:])`` with the same
        dtype/device as ``data``.
    """
    out_shape = (count,) + tuple(data.size()[1:])
    out = data.new_full(out_shape, fill)
    if data.dim() == 1:
        out[inds] = data
    else:
        out[inds, :] = data
    return out
