# This schedule applies DropFilter - a regularization method similar to Dropout, which drops entire convolutional
# filters instead of individual neurons.
#
# The sample below increases the test Top1 accuracy of Plain20 by 0.3% (from 90.55 to 90.85) when averaged across 3 trials.
#
# References:
# [1] DropFilter: Dropout for Convolutions
#     Zhengsu Chen, Jianwei Niu, Qi Tian
#     https://arxiv.org/abs/1810.09849
# [2] DropFilter: A Novel Regularization Method for Learning Convolutional Neural Networks
#     Hengyue Pan, Hui Jiang, Xin Niu, Yong Dou
#     https://arxiv.org/abs/1811.06783
#
#
#
# time python3 compress_classifier.py --arch=plain20_cifar ../../../data.cifar --lr=0.3 --epochs=180 --compress=../drop_filter/plain20_cifar_dropfilter_training_regularization.yaml -p=50 --gpus=0 --masks-sparsity --vs=0
#
# --- validate (epoch=179)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 90.760    Top5: 99.650    Loss: 0.362
#
# ==> Best [Top1: 90.880   Top5: 99.650   Sparsity:0.00   Params: 268336 on epoch: 170]
# Saving checkpoint to: logs/2019.03.23-185744/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 90.760    Top5: 99.650    Loss: 0.359
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2019.03.23-185744/2019.03.23-185744.log
#
# real    30m59.049s
# user    109m16.449s
# sys     12m9.995s

# Learning-rate schedulers, keyed by instance name. The `policies` section
# selects a scheduler through its `instance_name`.
lr_schedulers:
  training_lr:
    # NOTE(review): presumably PyTorch's StepLR, i.e. the learning rate is
    # multiplied by `gamma` once every `step_size` epochs - confirm against the
    # tool that consumes this file.
    class: StepLR
    step_size: 45
    gamma: 0.2

# Pruners, keyed by instance name. Only one is declared here; the `policies`
# section selects it through its `instance_name`.
pruners:
  random_filter_pruner:
    class: BernoulliFilterPruner
    # NOTE(review): presumably the fraction of filters masked each time the
    # pruner fires (0.1 = 10%) - confirm against BernoulliFilterPruner.
    desired_sparsity: 0.1
    group_type: Filters
    # The 19 weight tensors this pruner is applied to.
    weights:
      # Top-level conv1
      - module.conv1.weight
      # layer1: conv1 of entries 0-2, then conv2 of entries 0-2
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer1.1.conv2.weight
      - module.layer1.2.conv2.weight
      # layer2: conv1 of entries 0-2, then conv2 of entries 0-2
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      # layer3: conv1 of entries 0-2, then conv2 of entries 0-2
      - module.layer3.0.conv1.weight
      - module.layer3.1.conv1.weight
      - module.layer3.2.conv1.weight
      - module.layer3.0.conv2.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight

# Policies bind the scheduler and pruner instances declared in this file to an
# epoch window (`starting_epoch` / `ending_epoch`) and a `frequency`.
policies:
  # Learning-rate policy for the `training_lr` scheduler.
  # Note: `ending_epoch` (200) lies beyond the 180 epochs used by the sample
  # command in the header of this file.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 30
    ending_epoch: 200
    frequency: 1

  # DropFilter policy for the `random_filter_pruner` pruner.
  # NOTE(review): the `args` below are passed through to the consumer; their
  # exact semantics are not visible in this file - confirm against its docs.
  - pruner:
      instance_name: random_filter_pruner
      args:
        mini_batch_pruning_frequency: 16
        discard_masks_at_minibatch_end: true
        # Left disabled, as in the original schedule:
        # use_double_copies: true
        mask_on_forward_only: true
        mask_gradients: true
    starting_epoch: 15
    ending_epoch: 180
    frequency: 1
