'''
Config of FSD (Fully Sparse Detector) on the Waymo Open Dataset
'''

# Inherit the standard Waymo data config that ships with MMDetection3D.
_base_ = ['../_base_/datasets/waymoD5-3d-3class.py']

# Geometry shared by the voxel layer, voxel encoder and decode neck below.
sst_voxel_size = (0.32, 0.32, 6)
# Presumably (x_min, y_min, z_min, x_max, y_max, z_max) as in MMDetection3D
# -- confirm against the consumer of this value.
point_cloud_range = [
    -74.88, -74.88, -2,
    74.88, 74.88, 4,
]

# Detection categories; the class count is derived so the two never diverge.
class_names = ['Car', 'Pedestrian', 'Cyclist']
num_classes = len(class_names)

# A point takes part in voting only when its score exceeds the threshold.
foreground_score_thresh = (0.5, 0.25, 0.25)

# Instance Point Grouping (IPG) settings; the detector config consumes this
# dictionary through its `IPG` entry.
instance_point_grouping = {
    'type': 'IPG',

    # Voxel layer: MMDetection3D implementation, used unmodified.
    'voxel_layer': {
        'voxel_size': sst_voxel_size,
        'max_num_points': -1,
        'point_cloud_range': point_cloud_range,
        'max_voxels': (-1, -1),
    },

    # Voxel encoder: MMDetection3D implementation, used unmodified.
    'voxel_encoder': {
        'type': 'DynamicVFE',
        'in_channels': 5,
        'feat_channels': [64, 128],
        'with_distance': False,
        'voxel_size': sst_voxel_size,
        'with_cluster_center': True,
        'with_voxel_center': True,
        'point_cloud_range': point_cloud_range,
        'norm_cfg': {'type': 'naiveSyncBN1d', 'eps': 1e-3, 'momentum': 0.01},
    },

    # Middle encoder: SST implementation, used unmodified.
    'middle_encoder': {
        'type': 'SSTInputLayerV2',
        'window_shape': (12, 12, 1),
        'sparse_shape': (468, 468, 1),
        'shuffle_voxels': True,
        'debug': False,
        'drop_info': {
            0: {'max_tokens': 30, 'drop_range': (0, 30)},
            1: {'max_tokens': 60, 'drop_range': (30, 60)},
            2: {'max_tokens': 100, 'drop_range': (60, 100000)},
        },
        'pos_temperature': 1000,
        'normalize_pos': False,
    },

    # Backbone: SST implementation with one minor change -- the output is
    # not converted to dense BEV feature maps.
    'backbone': {
        'type': 'SSTv2',
        'd_model': [128, 128, 128, 128],
        'nhead': [8, 8, 8, 8],
        'num_blocks': 4,
        'dim_feedforward': [256, 256, 256, 256],
        'conv_in_channel': 128,
        'conv_out_channel': 128,
        'to_bev': False,  # keep the features sparse
    },

    # Point features are obtained by concatenating voxel features with the
    # offsets between each point and its voxel center.
    'decode_neck': {
        'type': 'Voxel2PointNeck',
        'voxel_size': sst_voxel_size,
        'point_cloud_range': point_cloud_range,
    },

    # An MLP applied to every point for classification and voting.
    'vote_head': {
        'type': 'FSDVoteHead',
        'in_channel': 131,
        'hidden_dims': [128, 128],
        'num_classes': num_classes,
        'loss_decode': {
            'type': 'FocalLoss',
            'use_sigmoid': True,
            'gamma': 2.0,
            'alpha': 0.8,
            'loss_weight': 1.0,
        },
        'loss_vote': {'type': 'L1Loss', 'loss_weight': 1.0},
    },

    # Connected Components Labeling (CCL).
    'ccl': {
        'point_cloud_range': point_cloud_range,
        # Per-class connected distance thresholds, in meters.
        'connected_dist': {'Car': 0.6, 'Cyclist': 0.4, 'Pedestrian': 0.15},
        'class_names': class_names,
    },
}

# Full detector: IPG stage, SIR backbone, sparse prediction head and the
# SIR2-based RoI head, together with their train/test settings.
model = {
    'type': 'FullySparseDetector',

    'IPG': instance_point_grouping,

    # SIR backbone with three layers.
    'backbone': {
        'type': 'SIR',
        'num_layers': 3,
        'in_channels': [148, 133, 133],
        'feat_channels': [[128, 128], [128, 128], [128, 128]],
        'norm_cfg': {'type': 'LN', 'eps': 1e-3},
        'mode': 'max',
        'act': 'gelu',
    },

    # Lightweight MLP head producing sparse predictions.
    'bbox_head': {
        'type': 'SparsePredictionHead',
        'num_classes': num_classes,
        'bbox_coder': {'type': 'BasePointBBoxCoder'},

        # Instance classification uses focal loss.
        'loss_cls': {
            'type': 'FocalLoss',
            'use_sigmoid': True,
            'gamma': 1.0,
            'alpha': 0.25,
            'loss_weight': 2.0,
        },

        # Box regression uses L1 loss.
        'loss_bbox': {'type': 'L1Loss', 'loss_weight': 1.0},

        # 128 * 3: group features from every SIR layer are concatenated.
        'in_channel': 384,
        'hidden_dim': [128, 128],
        'as_rpn': True,
    },

    # Group Correction followed by SIR2; adapted from the PartA2 roi_head.
    'roi_head': {
        'type': 'FullySparseROIHead',
        'num_classes': num_classes,
        'roi_extractor': {
            'type': 'GroupCorrection',
            'max_inbox_point': 512,
        },
        'bbox_head': {
            'type': 'SIR2',
            'num_classes': num_classes,
            'num_layers': 3,
            # Written in the original as (275 + 2, 131 + 13 + 2, 131 + 13 + 2).
            'in_channels': [277, 146, 146],
            'feat_channels': [[128, 128], [128, 128], [128, 128]],
            'mode': 'max',
            'act': 'gelu',
            'bbox_coder': {'type': 'DeltaXYZWLHRBBoxCoder'},
            'norm_cfg': {'type': 'LN', 'eps': 1e-3},
            'loss_bbox': {
                'type': 'L1Loss',
                'reduction': 'mean',
                'loss_weight': 1.0,
            },
            'loss_cls': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': True,
                'reduction': 'mean',
                'loss_weight': 1.0,
            },
        },
    },

    'train_cfg': {
        'score_thresh': foreground_score_thresh,
        'rpn': {
            'use_rotate_nms': True,
            'nms_pre': -1,
            'nms_thr': None,
            'score_thr': 0.1,
            'max_num': 500,
        },
        'rcnn': {
            'assigner': {
                'type': 'MaxIoUAssigner',
                'iou_calculator': {
                    'type': 'BboxOverlaps3D',
                    'coordinate': 'lidar',
                },
                'pos_iou_thr': 0.45,
                'neg_iou_thr': 0.45,
                'min_pos_iou': 0.45,
                'ignore_iof_thr': -1,
            },

            # Same proposal sampler as PartA2.
            'sampler': {
                'type': 'IoUNegPiecewiseSampler',
                'num': 256,
                'pos_fraction': 0.55,
                'neg_piece_fractions': [0.8, 0.2],
                'neg_iou_piece_thrs': [0.55, 0.1],
                'neg_pos_ub': -1,
                'add_gt_as_proposals': False,
                'return_iou': True,
            },

            # Thresholds for the soft IoU target used by SIR2.
            'cls_pos_thr': 0.75,
            'cls_neg_thr': 0.25,
        },
    },

    'test_cfg': {
        'score_thresh': foreground_score_thresh,
        'rpn': {
            'use_rotate_nms': True,
            'nms_pre': -1,
            'nms_thr': None,  # no NMS after SIR
            'score_thr': 0.1,
            'max_num': 500,
        },
        'rcnn': {
            'use_rotate_nms': True,
            'nms_pre': -1,
            'nms_thr': 0.25,  # NMS is applied after SIR2
            'score_thr': 0.1,
            'max_num': 500,
        },
    },
}

# Optimization settings, following SST.

lr = 1e-5
optimizer = {
    'type': 'AdamW',
    'lr': lr,
    'betas': (0.9, 0.999),
    'weight_decay': 0.05,
    # The 'norm' custom key gets a decay multiplier of zero.
    'paramwise_cfg': {'custom_keys': {'norm': {'decay_mult': 0.}}},
}
optimizer_config = {'grad_clip': {'max_norm': 10, 'norm_type': 2}}
lr_config = {
    # Per the original authors' note: with cyclic_times=1 this 'cyclic'
    # policy is in fact a cosine schedule.
    'policy': 'cyclic',
    'target_ratio': (100, 1e-3),
    'cyclic_times': 1,
    'step_ratio_up': 0.1,
}

# Data loading overrides applied on top of the inherited dataset config.
data = {
    'samples_per_gpu': 1,  # batch size == 1
    'workers_per_gpu': 4,
    'train': {
        'type': 'RepeatDataset',
        'times': 1,
        'dataset': {'load_interval': 1},
    },
}

# Epoch-based training that stops after 6 epochs.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 6}
