# This schedule follows the methodology proposed by Intel Labs China in the paper:
#   Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
#   NIPS 2016, https://arxiv.org/abs/1608.04493.
#
# We have not been able to see great results with our implementation of the paper, as we understood it.  In the paper
# two sets of weights are used: in the forward-pass Guo et al. use masked weights to compute the loss, but in the
# backward-pass they update the unmasked weights (using gradients computed from the masked-weights loss).
# To replicate this behavior, set the following in the SplicingPruner policy:
#   use_double_copies: True
#   mask_on_forward_only: True
#
# We found that using two copies of weights reduces the accuracy results, and so we disable this configuration in the
# example schedule below.
#
# The "mask_on_forward_only" configuration controls what we do after the weights are updated by the backward pass.
# In issue #53 (https://github.com/NervanaSystems/distiller/issues/53) we explain why in some cases masked weights
# will be updated to a non-zero value, even if their gradients are masked (e.g. when using SGD with momentum).
# Therefore, to circumvent this weights-update performed by the backward pass, we usually mask the weights again -
# right after the backward pass.  To disable this masking set:
#   mask_on_forward_only: True
#
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Best Top1: 91.49
#     Total MACs: 40,813,184
#     Total sparsity: 69.1%
#     # of parameters: 83,671
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml --vs=0 --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --masks-sparsity --num-best-scores=5
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.36322 | -0.00677 |    0.25811 |
# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |            641 |    0.00000 |    0.00000 |  6.25000 | 19.53125 |  0.00000 |   72.17882 | 0.12306 | -0.00732 |    0.05892 |
# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |            629 |    0.00000 |    0.00000 |  0.00000 | 19.14062 |  0.00000 |   72.69965 | 0.12015 | -0.00124 |    0.05719 |
# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |            670 |    0.00000 |    0.00000 |  0.00000 | 16.40625 |  0.00000 |   70.92014 | 0.10390 | -0.00880 |    0.05262 |
# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |            659 |    0.00000 |    0.00000 |  0.00000 | 14.84375 |  0.00000 |   71.39757 | 0.09757 | -0.00261 |    0.04895 |
# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |            714 |    0.00000 |    0.00000 |  0.00000 | 20.31250 |  0.00000 |   69.01042 | 0.13813 | -0.00660 |    0.07117 |
# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |            685 |    0.00000 |    0.00000 |  0.00000 | 19.14062 |  0.00000 |   70.26910 | 0.10906 |  0.00146 |    0.05552 |
# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           1534 |    0.00000 |    0.00000 |  0.00000 | 14.64844 |  0.00000 |   66.71007 | 0.10999 |  0.00096 |    0.05983 |
# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           2915 |    0.00000 |    0.00000 |  0.00000 |  8.30078 |  0.00000 |   68.37023 | 0.09245 | -0.00413 |    0.04926 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            137 |    0.00000 |    0.00000 |  0.00000 | 73.24219 |  0.00000 |   73.24219 | 0.20471 | -0.00300 |    0.09534 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           2738 |    0.00000 |    0.00000 |  0.00000 |  9.96094 |  0.00000 |   70.29080 | 0.08032 | -0.00450 |    0.04164 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           2791 |    0.00000 |    0.00000 |  0.00000 |  8.39844 |  0.00000 |   69.71571 | 0.07054 | -0.00283 |    0.03704 |
# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           2770 |    0.00000 |    0.00000 |  0.00000 | 11.03516 |  0.00000 |   69.94358 | 0.08060 | -0.00651 |    0.04184 |
# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           2745 |    0.00000 |    0.00000 |  0.00000 | 11.52344 |  0.00000 |   70.21484 | 0.06489 |  0.00081 |    0.03376 |
# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |           5852 |    0.00000 |    0.00000 |  0.00000 | 13.37891 |  0.00000 |   68.25087 | 0.07912 | -0.00226 |    0.04262 |
# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          11772 |    0.00000 |    0.00000 |  0.00000 |  4.83398 |  0.00000 |   68.06641 | 0.07390 | -0.00186 |    0.03979 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |            667 |    0.00000 |    0.00000 |  0.00000 | 67.43164 |  0.00000 |   67.43164 | 0.11225 | -0.00579 |    0.06123 |
# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          10906 |    0.00000 |    0.00000 |  0.00000 |  6.98242 |  0.00000 |   70.41558 | 0.07001 | -0.00362 |    0.03640 |
# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          10850 |    0.00000 |    0.00000 |  0.00000 |  8.20312 |  0.00000 |   70.56749 | 0.06218 | -0.00419 |    0.03227 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11053 |    0.00000 |    0.00000 |  0.00000 | 10.42480 |  0.00000 |   70.01682 | 0.06197 | -0.00483 |    0.03245 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11871 |    0.00000 |    0.00000 |  0.00000 | 23.90137 |  0.00000 |   67.79785 | 0.03901 |  0.00004 |    0.02103 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.56654 | -0.00002 |    0.48838 |
# | 22 | Total sparsity:                     | -              |        270896 |          83671 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   69.11324 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-11-03 12:20:32,056 - Total sparsity: 69.11
# 2018-11-03 12:20:32,108 - --- validate (epoch=359)-----------
# 2018-11-03 12:20:32,109 - 10000 samples (256 per mini-batch)
# 2018-11-03 12:20:34,825 - ==> Top1: 91.430    Top5: 99.670    Loss: 0.368
#
# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.490 on Epoch: 339   <====== BEST RESULT
# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.480 on Epoch: 222
# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.470 on Epoch: 280
# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.450 on Epoch: 328
# 2018-11-03 12:20:34,827 - ==> Best Top1: 91.440 on Epoch: 298
# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 314
# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.440 on Epoch: 319
# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 283
# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 304
# 2018-11-03 12:20:34,828 - ==> Best Top1: 91.430 on Epoch: 359
# 2018-11-03 12:20:34,828 - Saving checkpoint to: logs/2018.11.03-111011/checkpoint.pth.tar
# 2018-11-03 12:20:34,859 - --- test ---------------------
# 2018-11-03 12:20:34,859 - 10000 samples (256 per mini-batch)
# 2018-11-03 12:20:37,610 - ==> Top1: 91.430    Top5: 99.670    Loss: 0.368


# NOTE(review): presumably the version of the schedule-file format - confirm against the loader.
version: 1
# Pruner definitions.  The pruner policy further down refers to an entry here by its name.
pruners:
  pruner1:
    class: SplicingPruner
    # NOTE(review): the trailing "# 0.6" / "# 0.7" look like alternative values that were tried - confirm.
    low_thresh_mult: 0.9 # 0.6
    hi_thresh_mult: 1.1 # 0.7
    sensitivity_multiplier: 0.015
    # Per-tensor sensitivity values, keyed by parameter name.
    # The 20 tensors listed here are rows 1-20 of the results table at the top of this file.
    # module.conv1.weight and module.fc.weight are commented out; both show 0% fine-grained
    # sparsity in that table.
    sensitivities:
      #module.conv1.weight: 0.50
      module.layer1.0.conv1.weight: 0.050
      module.layer1.0.conv2.weight: 0.050
      module.layer1.1.conv1.weight: 0.050
      module.layer1.1.conv2.weight: 0.050
      module.layer1.2.conv1.weight: 0.010
      module.layer1.2.conv2.weight: 0.050
      module.layer2.0.conv1.weight: 0.010
      module.layer2.0.conv2.weight: 0.030
      module.layer2.0.downsample.0.weight: 0.050
      module.layer2.1.conv1.weight: 0.050
      module.layer2.1.conv2.weight: 0.050
      module.layer2.2.conv1.weight: 0.040
      module.layer2.2.conv2.weight: 0.050
      module.layer3.0.conv1.weight: 0.040
      module.layer3.0.conv2.weight: 0.020
      module.layer3.0.downsample.0.weight: 0.050
      module.layer3.1.conv1.weight: 0.050
      module.layer3.1.conv2.weight: 0.050
      module.layer3.2.conv1.weight: 0.050
      module.layer3.2.conv2.weight: 0.050
      #module.fc.weight

# Learning-rate scheduler definitions.  The lr_scheduler policy further down refers to
# "training_lr" by name.
lr_schedulers:
  # NOTE(review): presumably resolves to torch.optim.lr_scheduler.StepLR, i.e. the learning rate
  # is multiplied by gamma (0.10) once every step_size (45) scheduler steps - confirm.
  training_lr:
    class: StepLR
    step_size: 45
    gamma: 0.10

# Policies bind the pruner and the LR scheduler defined above to epoch ranges.
policies:
  # pruner1 is bound to epochs 0-100 with frequency 1.
  # NOTE(review): whether ending_epoch is inclusive is not visible here - confirm.
  - pruner:
      instance_name: pruner1
      args:
        keep_mask: True
        mini_batch_pruning_frequency: 1
        # Per the header of this file, True disables re-masking the weights right after the
        # backward pass.
        mask_on_forward_only: True
        # Left disabled: the header notes that using two copies of the weights reduced accuracy.
        #use_double_copies: True
    starting_epoch: 0
    ending_epoch: 100
    frequency: 1


  # training_lr is bound to epochs 0-400 with frequency 1, a range that extends past the end of
  # the pruning range.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
