# This schedule follows the methodology proposed by Intel Labs China in the paper:
#   Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
#   NIPS 2016, https://arxiv.org/abs/1608.04493.
#
# Best Top1 is 75.492 (on epoch 93) vs. the published Top1 of the dense pretrained model: 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
# Total sparsity: 80.05%
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.network_surgery.yaml --validation-split=0  --masks-sparsity --num-best-scores=10
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11209 | -0.00048 |    0.06861 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |            930 |    0.00000 |    0.00000 |  3.12500 | 77.29492 |  7.81250 |   77.29492 | 0.06047 | -0.00461 |    0.02365 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |           5217 |    0.00000 |    0.00000 |  7.81250 | 50.19531 |  6.25000 |   85.84798 | 0.02355 |  0.00061 |    0.00760 |
# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |           2579 |    0.00000 |    0.00000 |  6.25000 | 84.25903 | 13.67188 |   84.25903 | 0.02869 |  0.00024 |    0.01018 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |           3192 |    0.00000 |    0.00000 |  1.56250 | 80.51758 | 14.45312 |   80.51758 | 0.04863 | -0.00286 |    0.01787 |
# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |           2071 |    0.00000 |    0.00000 | 14.06250 | 87.35962 |  6.25000 |   87.35962 | 0.02289 |  0.00075 |    0.00767 |
# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |           4192 |    0.00000 |    0.00000 |  6.25000 | 51.92871 |  0.00000 |   88.62847 | 0.02106 |  0.00024 |    0.00653 |
# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |           1994 |    0.00000 |    0.00000 |  0.00000 | 87.82959 |  7.03125 |   87.82959 | 0.02508 |  0.00016 |    0.00811 |
# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |           3433 |    0.00000 |    0.00000 |  7.81250 | 79.04663 |  0.00000 |   79.04663 | 0.02381 | -0.00008 |    0.01026 |
# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |           4777 |    0.00000 |    0.00000 |  0.00000 | 43.87207 |  0.00000 |   87.04156 | 0.02238 | -0.00031 |    0.00772 |
# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |           1970 |    0.00000 |    0.00000 |  0.00000 | 87.97607 | 10.15625 |   87.97607 | 0.02456 | -0.00123 |    0.00793 |
# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |           6464 |    0.00000 |    0.00000 |  3.90625 | 80.27344 |  0.00000 |   80.27344 | 0.02816 | -0.00097 |    0.01151 |
# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          26647 |    0.00000 |    0.00000 |  0.00000 | 38.12256 |  0.00000 |   81.92885 | 0.01681 | -0.00017 |    0.00674 |
# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |           7191 |    0.00000 |    0.00000 |  0.00000 | 89.02740 | 27.53906 |   89.02740 | 0.02141 |  0.00021 |    0.00645 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |          14155 |    0.00000 |    0.00000 |  0.00000 | 89.20059 | 14.25781 |   89.20059 | 0.01808 | -0.00021 |    0.00508 |
# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |           6929 |    0.00000 |    0.00000 | 17.77344 | 89.42719 |  0.00000 |   89.42719 | 0.01286 |  0.00023 |    0.00379 |
# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          15005 |    0.00000 |    0.00000 |  0.00000 | 61.30371 |  2.34375 |   89.82408 | 0.01468 |  0.00031 |    0.00421 |
# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |           6420 |    0.00000 |    0.00000 |  0.00000 | 90.20386 | 20.11719 |   90.20386 | 0.01755 | -0.00087 |    0.00493 |
# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |          10180 |    0.00000 |    0.00000 |  2.53906 | 84.46655 |  0.00000 |   84.46655 | 0.01792 | -0.00034 |    0.00644 |
# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |          16802 |    0.00000 |    0.00000 |  0.00000 | 51.42822 |  0.00000 |   88.60541 | 0.01552 | -0.00007 |    0.00489 |
# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |           7136 |    0.00000 |    0.00000 |  0.00000 | 89.11133 |  5.85938 |   89.11133 | 0.01884 | -0.00004 |    0.00574 |
# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |           9659 |    0.00000 |    0.00000 |  1.56250 | 85.26154 |  0.00000 |   85.26154 | 0.01777 | -0.00035 |    0.00645 |
# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |          26960 |    0.00000 |    0.00000 |  0.00000 | 27.93579 |  0.00000 |   81.71658 | 0.01683 | -0.00027 |    0.00683 |
# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |           8708 |    0.00000 |    0.00000 |  0.00000 | 86.71265 | 17.38281 |   86.71265 | 0.01841 | -0.00022 |    0.00622 |
# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |          21811 |    0.00000 |    0.00000 |  0.00000 | 83.35953 |  0.00000 |   83.35953 | 0.02350 | -0.00039 |    0.00884 |
# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         121299 |    0.00000 |    0.00000 |  0.00000 | 39.09607 |  0.00000 |   79.43471 | 0.01348 | -0.00020 |    0.00567 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          31275 |    0.00000 |    0.00000 |  0.00000 | 88.06953 |  6.05469 |   88.06953 | 0.01678 |  0.00009 |    0.00543 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |          56268 |    0.00000 |    0.00000 |  0.00000 | 89.26773 |  5.27344 |   89.26773 | 0.01169 |  0.00011 |    0.00349 |
# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          30611 |    0.00000 |    0.00000 |  8.39844 | 88.32283 |  0.00000 |   88.32283 | 0.01105 | -0.00007 |    0.00350 |
# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |          70111 |    0.00000 |    0.00000 |  0.00000 | 52.11945 |  0.00000 |   88.11323 | 0.01070 | -0.00001 |    0.00345 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          29345 |    0.00000 |    0.00000 |  0.00000 | 88.80577 |  2.34375 |   88.80577 | 0.01450 | -0.00058 |    0.00446 |
# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          28788 |    0.00000 |    0.00000 |  2.14844 | 89.01825 |  0.00000 |   89.01825 | 0.01121 | -0.00007 |    0.00343 |
# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |          74201 |    0.00000 |    0.00000 |  0.00000 | 43.12592 |  0.00000 |   87.41981 | 0.01048 | -0.00024 |    0.00353 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          30558 |    0.00000 |    0.00000 |  0.00000 | 88.34305 |  0.87891 |   88.34305 | 0.01349 | -0.00015 |    0.00431 |
# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          30595 |    0.00000 |    0.00000 |  0.48828 | 88.32893 |  0.00000 |   88.32893 | 0.01214 | -0.00004 |    0.00388 |
# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |          77216 |    0.00000 |    0.00000 |  0.00000 | 40.67383 |  0.00000 |   86.90864 | 0.01034 | -0.00017 |    0.00359 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          31672 |    0.00000 |    0.00000 |  0.00000 | 87.91809 |  3.80859 |   87.91809 | 0.01290 | -0.00035 |    0.00422 |
# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          32026 |    0.00000 |    0.00000 |  0.09766 | 87.78305 |  0.00000 |   87.78305 | 0.01259 | -0.00013 |    0.00415 |
# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |          78997 |    0.00000 |    0.00000 |  0.00000 | 40.54871 |  0.00000 |   86.60668 | 0.01032 | -0.00024 |    0.00363 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          31803 |    0.00000 |    0.00000 |  0.00000 | 87.86812 |  1.26953 |   87.86812 | 0.01293 | -0.00056 |    0.00424 |
# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          33310 |    0.00000 |    0.00000 |  0.00000 | 87.29324 |  0.00000 |   87.29324 | 0.01362 | -0.00006 |    0.00460 |
# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |          80320 |    0.00000 |    0.00000 |  0.00000 | 42.06696 |  0.00000 |   86.38238 | 0.01058 | -0.00027 |    0.00375 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          33272 |    0.00000 |    0.00000 |  0.00000 | 87.30774 |  1.66016 |   87.30774 | 0.01368 | -0.00094 |    0.00463 |
# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |         138518 |    0.00000 |    0.00000 |  0.00000 | 73.57979 |  0.00000 |   73.57979 | 0.01890 | -0.00045 |    0.00915 |
# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         565539 |    0.00000 |    0.00000 |  0.00000 | 33.63037 |  0.00000 |   76.02933 | 0.00967 | -0.00020 |    0.00454 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         230901 |    0.00000 |    0.00000 |  0.00000 | 77.97956 |  0.00000 |   77.97956 | 0.01184 | -0.00016 |    0.00526 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |         367404 |    0.00000 |    0.00000 |  0.00000 | 82.48081 |  0.00000 |   82.48081 | 0.00742 |  0.00003 |    0.00292 |
# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         292890 |    0.00000 |    0.00000 |  0.00000 | 72.06783 |  0.00000 |   72.06783 | 0.01192 | -0.00036 |    0.00595 |
# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         710178 |    0.00000 |    0.00000 |  0.00000 | 19.51637 |  0.00000 |   69.89873 | 0.00989 | -0.00050 |    0.00518 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         229702 |    0.00000 |    0.00000 |  0.00000 | 78.09391 |  0.00000 |   78.09391 | 0.01162 |  0.00020 |    0.00515 |
# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         305143 |    0.00000 |    0.00000 |  0.00000 | 70.89930 |  0.00000 |   70.89930 | 0.01464 | -0.00017 |    0.00748 |
# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         618775 |    0.00000 |    0.00000 |  0.00000 | 43.47458 |  0.00000 |   73.77290 | 0.00855 | -0.00036 |    0.00423 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         172278 |    0.00000 |    0.00000 |  0.00000 | 83.57029 |  0.19531 |   83.57029 | 0.01065 |  0.00029 |    0.00403 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         330450 |    0.00000 |    0.14648 |  0.00000 |  0.00000 |  0.00000 |   83.86475 | 0.03022 |  0.00471 |    0.01085 |
# | 54 | Total sparsity:                     | -                  |      25502912 |        5087275 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   80.05218 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-11-09 01:36:20,957 - Total sparsity: 80.05
#
# 2018-11-09 01:36:21,115 - --- validate (epoch=96)-----------
# 2018-11-09 01:36:21,115 - 50000 samples (256 per mini-batch)
# 2018-11-09 01:36:41,267 - Epoch: [96][   50/  195]    Loss 0.698889    Top1 81.335938    Top5 95.867188
# 2018-11-09 01:36:49,363 - Epoch: [96][  100/  195]    Loss 0.816939    Top1 78.742188    Top5 94.550781
# 2018-11-09 01:36:57,669 - Epoch: [96][  150/  195]    Loss 0.930078    Top1 76.424479    Top5 93.192708
# 2018-11-09 01:37:04,449 - ==> Top1: 75.406    Top5: 92.698    Loss: 0.974
#
# 2018-11-09 01:37:04,525 - ==> Best Top1: 75.492 on Epoch: 93
# 2018-11-09 01:37:04,525 - ==> Best Top1: 75.480 on Epoch: 83
# 2018-11-09 01:37:04,525 - ==> Best Top1: 75.480 on Epoch: 87
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.474 on Epoch: 24
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.472 on Epoch: 92
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.456 on Epoch: 27
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.446 on Epoch: 26
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.444 on Epoch: 89
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.444 on Epoch: 94
# 2018-11-09 01:37:04,526 - ==> Best Top1: 75.438 on Epoch: 80
# 2018-11-09 01:37:04,526 - Saving checkpoint to: logs/2018.11.06-213739/checkpoint.pth.tar


# Schedule-file format version.
# NOTE(review): presumably checked by the schedule parser -- confirm.
version: 1
pruners:
  pruner1:
    # Splicing (prune + splice-back) pruner, following the Dynamic Network
    # Surgery method cited in the header.
    class: SplicingPruner
    # NOTE(review): presumably these scale the per-layer threshold into a low
    # (prune) and a high (splice back) threshold -- confirm against the
    # SplicingPruner implementation. The trailing values (0.6 / 0.7) are
    # presumably earlier settings that were tried.
    low_thresh_mult: 0.9  # 0.6
    hi_thresh_mult: 1.1  # 0.7
    # NOTE(review): presumably scales the sensitivities below -- confirm.
    sensitivity_multiplier: 0.015
    # Per-parameter sensitivities. In the results table above, layers with
    # higher values generally end up sparser (most 0.60 layers reach ~86-90%).
    sensitivities:
      # conv1 is commented out, so it is not pruned (0% sparse in the table).
      # module.conv1.weight: 0.60
      module.layer1.0.conv1.weight: 0.10
      module.layer1.0.conv2.weight: 0.40
      module.layer1.0.conv3.weight: 0.40
      module.layer1.0.downsample.0.weight: 0.20
      module.layer1.1.conv1.weight: 0.60
      module.layer1.1.conv2.weight: 0.60
      module.layer1.1.conv3.weight: 0.60
      module.layer1.2.conv1.weight: 0.30
      module.layer1.2.conv2.weight: 0.60
      module.layer1.2.conv3.weight: 0.60

      module.layer2.0.conv1.weight: 0.30
      module.layer2.0.conv2.weight: 0.40
      module.layer2.0.conv3.weight: 0.60
      module.layer2.0.downsample.0.weight: 0.50
      module.layer2.1.conv1.weight: 0.60
      module.layer2.1.conv2.weight: 0.60
      module.layer2.1.conv3.weight: 0.60
      module.layer2.2.conv1.weight: 0.40
      module.layer2.2.conv2.weight: 0.60
      module.layer2.2.conv3.weight: 0.60
      module.layer2.3.conv1.weight: 0.50
      module.layer2.3.conv2.weight: 0.40
      module.layer2.3.conv3.weight: 0.50

      module.layer3.0.conv1.weight: 0.40
      module.layer3.0.conv2.weight: 0.30
      module.layer3.0.conv3.weight: 0.60
      module.layer3.0.downsample.0.weight: 0.60
      module.layer3.1.conv1.weight: 0.60
      module.layer3.1.conv2.weight: 0.60
      module.layer3.1.conv3.weight: 0.60
      module.layer3.2.conv1.weight: 0.60
      module.layer3.2.conv2.weight: 0.60
      module.layer3.2.conv3.weight: 0.60
      module.layer3.3.conv1.weight: 0.60
      module.layer3.3.conv2.weight: 0.60
      module.layer3.3.conv3.weight: 0.60
      module.layer3.4.conv1.weight: 0.60
      module.layer3.4.conv2.weight: 0.60
      module.layer3.4.conv3.weight: 0.60
      module.layer3.5.conv1.weight: 0.60
      module.layer3.5.conv2.weight: 0.60
      module.layer3.5.conv3.weight: 0.60

      module.layer4.0.conv1.weight: 0.20
      module.layer4.0.conv2.weight: 0.30
      module.layer4.0.conv3.weight: 0.30
      module.layer4.0.downsample.0.weight: 0.40
      module.layer4.1.conv1.weight: 0.15
      module.layer4.1.conv2.weight: 0.15
      module.layer4.1.conv3.weight: 0.30
      module.layer4.2.conv1.weight: 0.15
      module.layer4.2.conv2.weight: 0.30
      module.layer4.2.conv3.weight: 0.45
      module.fc.weight: 0.50

lr_schedulers:
  # Step decay: the learning rate is multiplied by gamma every step_size
  # epochs. With --lr=0.0005 and --epochs=100 (header command), the LR drops
  # at epochs 45 and 90.
  training_lr:
    class: StepLR
    gamma: 0.1
    step_size: 45

policies:
  # Invoke pruner1 once per epoch, from epoch 0 up to ending_epoch 47.
  - pruner:
      instance_name: pruner1
      args:
        keep_mask: true
        # mini_batch_pruning_frequency: 1
        mask_on_forward_only: true
    starting_epoch: 0
    ending_epoch: 47
    frequency: 1

  # Step the LR scheduler once per epoch. ending_epoch (400) is well past the
  # 100-epoch run in the header command, so it covers the entire run.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
