# This schedule performs DropFilter, a regularization method similar to Dropout that drops entire convolutional
# filters instead of individual neurons.
# The original intent of DropFilter is to act as a regularizer and reduce the generalization error of the network.
# Here, however, we employ higher filter-dropping rates (the rate is increased over time by following an AGP
# schedule) in order to make the network more robust to filter-pruning.  We test this robustness using sensitivity
# analysis.
#
# References:
# [1] DropFilter: Dropout for Convolutions
#     Zhengsu Chen Jianwei Niu Qi Tian
#     https://arxiv.org/abs/1810.09849
# [2] DropFilter: A Novel Regularization Method for Learning Convolutional Neural Networks
#     Hengyue Pan, Hui Jiang, Xin Niu, Yong Dou
#     https://arxiv.org/abs/1811.06783
#
#
#
# time python3 compress_classifier.py --arch=plain20_cifar ../../../data.cifar --lr=0.3 --compress=plain20_cifar_dropfilter_training.yaml -p=50 --gpus=0 --masks-sparsity --vs=0 --epochs=220
#
# --- validate (epoch=219)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 89.410    Top5: 99.550    Loss: 0.454
#
# ==> Best [Top1: 89.610   Top5: 99.560   Sparsity:0.00   Params: 268336 on epoch: 139]
# Saving checkpoint to: logs/2019.03.24-133353/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 89.410    Top5: 99.550    Loss: 0.422
#
# real    37m16.853s
# user    131m1.775s
# sys     15m12.706s

# Learning-rate schedulers. The `training_lr` instance defined here is referenced
# by name (`instance_name`) from the `lr_scheduler` entry under `policies`.
lr_schedulers:
  training_lr:
    # NOTE(review): presumably torch.optim.lr_scheduler.StepLR, i.e. the learning
    # rate is multiplied by `gamma` once every `step_size` scheduler steps - confirm.
    class: StepLR
    step_size: 45
    gamma: 0.20

pruners:
  # Single pruner instance; referenced by name from the `pruner` entry under
  # `policies`.
  random_filter_pruner:
    class: BernoulliFilterPruner_AGP
    # Sparsity bounds for the pruner: 0.05 initially, 0.50 finally. The file
    # header describes the drop rate as increasing over time on an AGP schedule.
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Filters
    # Weight tensors this pruner is applied to, one per line.
    weights:
      - module.conv1.weight
      # layer1: conv1 of each block, then conv2 of each block
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer1.1.conv2.weight
      - module.layer1.2.conv2.weight
      # layer2: conv1 of each block, then conv2 of each block
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      # layer3: conv1 of each block, then conv2 of each block
      - module.layer3.0.conv1.weight
      - module.layer3.1.conv1.weight
      - module.layer3.2.conv1.weight
      - module.layer3.0.conv2.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight

policies:
  # Learning-rate policy: binds the `training_lr` scheduler to the epoch window
  # starting_epoch=30 .. ending_epoch=200, with a frequency of 1.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 30
    ending_epoch: 200
    frequency: 1

  # Filter-dropping policy: binds `random_filter_pruner` to the epoch window
  # starting_epoch=15 .. ending_epoch=220, with a frequency of 1.
  - pruner:
      instance_name: random_filter_pruner
      args:
        # NOTE(review): the option names suggest masks are re-drawn every 16
        # mini-batches, applied on the forward pass only and discarded when the
        # mini-batch ends - confirm against the scheduler implementation.
        mini_batch_pruning_frequency: 16
        discard_masks_at_minibatch_end: true
        use_double_copies: true
        mask_on_forward_only: true
        mask_gradients: true
    starting_epoch: 15
    ending_epoch: 220
    frequency: 1
