#
# This schedule ranks filters for pruning by the Average Percentage of Zeros (APoZ) in their activations.
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-split=0   --num-best-scores=10 --name="resnet50_filters_v5_APoZ"   --act-stats=valid
#
# Results:
#   Best Top1: 73.926 on Epoch: 88 (epochs 0-3 score higher but were logged before most of the pruning was applied)
#   No. of Parameters: 12,335,296 (of 25,502,912) = 48.37% dense (51.63% sparse)
#   Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
#
# In our tests, ranking filters by their L1-norm produced better results than ranking them by APoZ.
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11140 | -0.00038 |    0.06812 |
# |  1 | module.layer1.0.conv1.weight        | (32, 64, 1, 1)     |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07328 | -0.00361 |    0.04403 |
# |  2 | module.layer1.0.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03226 |  0.00112 |    0.02039 |
# |  3 | module.layer1.0.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03809 | -0.00044 |    0.02411 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05170 | -0.00309 |    0.02873 |
# |  5 | module.layer1.1.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03025 |  0.00064 |    0.02077 |
# |  6 | module.layer1.1.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03562 | -0.00026 |    0.02537 |
# |  7 | module.layer1.1.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03181 |  0.00034 |    0.02038 |
# |  8 | module.layer1.2.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02967 |  0.00092 |    0.02179 |
# |  9 | module.layer1.2.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03353 |  0.00027 |    0.02527 |
# | 10 | module.layer1.2.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02930 | -0.00182 |    0.01762 |
# | 11 | module.layer2.0.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03387 | -0.00075 |    0.02381 |
# | 12 | module.layer2.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02261 |  0.00020 |    0.01692 |
# | 13 | module.layer2.0.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02730 |  0.00005 |    0.01709 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02064 | -0.00046 |    0.01219 |
# | 15 | module.layer2.1.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32709 |    0.00000 |    0.00000 |  0.00000 |  0.18005 |  0.00000 |    0.18005 | 0.01662 | -0.00007 |    0.00987 |
# | 16 | module.layer2.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02609 |  0.00103 |    0.01699 |
# | 17 | module.layer2.1.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02226 | -0.00126 |    0.01299 |
# | 18 | module.layer2.2.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32762 |    0.00000 |    0.00000 |  0.00000 |  0.01831 |  0.00000 |    0.01831 | 0.02207 | -0.00059 |    0.01529 |
# | 19 | module.layer2.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02295 |  0.00058 |    0.01617 |
# | 20 | module.layer2.2.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02534 | -0.00018 |    0.01791 |
# | 21 | module.layer2.3.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02274 | -0.00091 |    0.01690 |
# | 22 | module.layer2.3.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02277 | -0.00017 |    0.01722 |
# | 23 | module.layer2.3.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02284 | -0.00045 |    0.01592 |
# | 24 | module.layer3.0.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02876 | -0.00082 |    0.02054 |
# | 25 | module.layer3.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01704 |  0.00005 |    0.01260 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02278 | -0.00033 |    0.01635 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01442 | -0.00002 |    0.00989 |
# | 28 | module.layer3.1.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01416 | -0.00035 |    0.01013 |
# | 29 | module.layer3.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01573 |  0.00019 |    0.01148 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01877 | -0.00067 |    0.01360 |
# | 31 | module.layer3.2.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01479 | -0.00027 |    0.01061 |
# | 32 | module.layer3.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01525 | -0.00027 |    0.01140 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01807 | -0.00052 |    0.01331 |
# | 34 | module.layer3.3.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01624 | -0.00040 |    0.01203 |
# | 35 | module.layer3.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01500 | -0.00029 |    0.01149 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01732 | -0.00102 |    0.01288 |
# | 37 | module.layer3.4.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01668 | -0.00060 |    0.01245 |
# | 38 | module.layer3.4.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01506 | -0.00054 |    0.01152 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01712 | -0.00145 |    0.01274 |
# | 40 | module.layer3.5.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01813 | -0.00044 |    0.01367 |
# | 41 | module.layer3.5.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01540 | -0.00056 |    0.01172 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01822 | -0.00198 |    0.01375 |
# | 43 | module.layer4.0.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02157 | -0.00070 |    0.01664 |
# | 44 | module.layer4.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01212 | -0.00013 |    0.00939 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01466 | -0.00048 |    0.01138 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |        2097152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00904 | -0.00023 |    0.00690 |
# | 47 | module.layer4.1.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01414 | -0.00026 |    0.01102 |
# | 48 | module.layer4.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01180 | -0.00054 |    0.00927 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01469 |  0.00013 |    0.01141 |
# | 50 | module.layer4.2.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01745 | -0.00000 |    0.01368 |
# | 51 | module.layer4.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01081 | -0.00037 |    0.00848 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01328 | -0.00003 |    0.01009 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |        2048000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03337 |  0.00000 |    0.02297 |
# | 54 | Total sparsity:                     | -                  |      12335296 |       12335231 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00053 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-12-13 16:41:13,433 - Total sparsity: 0.00
#
# 2018-12-13 16:41:13,441 - --- validate (epoch=99)-----------
# 2018-12-13 16:41:13,441 - 50000 samples (256 per mini-batch)
# 2018-12-13 16:41:39,719 - Epoch: [99][   50/  195]    Loss 0.747123    Top1 80.109375    Top5 95.187500
# 2018-12-13 16:41:54,671 - Epoch: [99][  100/  195]    Loss 0.871733    Top1 77.542969    Top5 93.890625
# 2018-12-13 16:42:09,554 - Epoch: [99][  150/  195]    Loss 0.994601    Top1 74.940104    Top5 92.312500
# 2018-12-13 16:42:22,252 - ==> Top1: 73.878    Top5: 91.802    Loss: 1.045
#
# 2018-12-13 16:42:22,343 - Generating logs/resnet50_filters_v5_APoZ___2018.12.11-193402/apoz_channels
# 2018-12-13 16:42:22,778 - Generating logs/resnet50_filters_v5_APoZ___2018.12.11-193402/l1_channels
# 2018-12-13 16:42:23,027 - Generating logs/resnet50_filters_v5_APoZ___2018.12.11-193402/sparsity
# 2018-12-13 16:42:23,036 - ==> Best Top1: 76.546 on Epoch: 0
# 2018-12-13 16:42:23,036 - ==> Best Top1: 75.498 on Epoch: 1
# 2018-12-13 16:42:23,036 - ==> Best Top1: 74.726 on Epoch: 2
# 2018-12-13 16:42:23,036 - ==> Best Top1: 74.070 on Epoch: 3
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.926 on Epoch: 88  <====
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.908 on Epoch: 83
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.902 on Epoch: 92
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.898 on Epoch: 79
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.898 on Epoch: 96
# 2018-12-13 16:42:23,036 - ==> Best Top1: 73.896 on Epoch: 86
# 2018-12-13 16:42:23,036 - Saving checkpoint to: logs/resnet50_filters_v5_APoZ___2018.12.11-193402/resnet50_filters_v5_APoZ_checkpoint.pth.tar
# 2018-12-13 16:42:23,211 - --- test ---------------------
# 2018-12-13 16:42:23,211 - 50000 samples (256 per mini-batch)
# 2018-12-13 16:42:41,777 - Test: [   50/  195]    Loss 0.747123    Top1 80.109375    Top5 95.187500
# 2018-12-13 16:42:50,383 - Test: [  100/  195]    Loss 0.871733    Top1 77.542969    Top5 93.890625
# 2018-12-13 16:42:59,887 - Test: [  150/  195]    Loss 0.994601    Top1 74.940104    Top5 92.312500
# 2018-12-13 16:43:06,822 - ==> Top1: 73.878    Top5: 91.802    Loss: 1.045

version: 1

pruners:
  # NOTE: fc_pruner is declared but no entry under `policies` schedules it,
  # so it never prunes anything in this run (module.fc.weight is still fully
  # dense in the results table in the header).
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.87
    # A list, like every other pruner's `weights`, so membership checks are
    # exact-name matches rather than operating on a bare string.
    weights:
      - module.fc.weight

  # Gradually prunes whole filters, ranked by the APoZ of their activations,
  # from 5% up to 50% of the filters in each listed layer.
  # conv3 and downsample layers are not listed; presumably because their
  # outputs feed the residual addition and must keep matching shapes (verify).
  filter_pruner:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Filters
    weights:
      # 1x1 conv1 layer of every block.
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer3.0.conv1.weight
      - module.layer3.1.conv1.weight
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.4.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer4.0.conv1.weight
      - module.layer4.1.conv1.weight
      - module.layer4.2.conv1.weight

      # 3x3 conv2 layer of every block.
      - module.layer1.0.conv2.weight
      - module.layer1.1.conv2.weight
      - module.layer1.2.conv2.weight
      - module.layer2.0.conv2.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.3.conv2.weight
      - module.layer3.0.conv2.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight
      - module.layer3.3.conv2.weight
      - module.layer3.4.conv2.weight
      - module.layer3.5.conv2.weight
      - module.layer4.0.conv2.weight
      - module.layer4.1.conv2.weight
      - module.layer4.2.conv2.weight

  # NOTE: fine_pruner is declared but no entry under `policies` schedules it,
  # so it never prunes anything in this run (the layer4 weights show 0% fine
  # sparsity in the results table in the header).
  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights:
      - module.layer4.0.conv2.weight
      - module.layer4.0.conv3.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight
      - module.layer4.1.conv2.weight
      - module.layer4.1.conv3.weight
      - module.layer4.2.conv1.weight
      - module.layer4.2.conv2.weight
      - module.layer4.2.conv3.weight

extensions:
  # Thinning step: removes the filters zeroed by filter_pruner from the model,
  # shrinking the layer shapes (see the halved conv1/conv2 shapes in the header).
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet50
    dataset: imagenet

lr_schedulers:
  # Exponential decay: every scheduler step multiplies the learning rate by gamma.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.95

policies:
  # NOTE: only filter_pruner is scheduled here. fc_pruner and fine_pruner are
  # declared under `pruners` but have no policy, so they prune nothing.

  # Gradually prune filters once per epoch between epochs 1 and 30.
  - pruner:
      instance_name: filter_pruner
    starting_epoch: 1
    ending_epoch: 30
    frequency: 1

  # After completing the pruning, we perform network thinning and continue fine-tuning.
  - extension:
      instance_name: net_thinner
    epochs: [31]

  # Decay the learning rate once per epoch between epochs 40 and 80.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 40
    ending_epoch: 80
    frequency: 1
