#
# This schedule performs 1x1x8 block pruning, using L1-norm ranking to select blocks and AGP to set the pruning-rate decay.
# The final Linear layer (FC) is also pruned to 70%.
#
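# Best Top1: 76.358 (epoch 72) vs. 76.15 baseline (+0.2%)
# (The higher Top1 scores at epochs 0-3 in the log below were measured before pruning finished at epoch 30.)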
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-split=0 --num-best-scores=10
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11032 | -0.00044 |    0.06760 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |           4096 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06353 | -0.00371 |    0.03587 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |          11064 |    0.00000 |    0.00000 |  0.00000 | 28.12500 |  7.81250 |   69.98698 | 0.02277 |  0.00061 |    0.00835 |
# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03142 |  0.00034 |    0.01880 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05117 | -0.00298 |    0.02858 |
# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02722 |  0.00105 |    0.01803 |
# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |          11064 |    0.00000 |    0.00000 |  0.00000 | 18.75000 |  1.56250 |   69.98698 | 0.02097 |  0.00016 |    0.00841 |
# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02896 | -0.00002 |    0.01815 |
# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02671 |  0.00015 |    0.01929 |
# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |          11064 |    0.00000 |    0.00000 |  0.00000 | 13.47656 |  0.00000 |   69.98698 | 0.02149 | -0.00033 |    0.00930 |
# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02755 | -0.00215 |    0.01658 |
# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03152 | -0.00126 |    0.02213 |
# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          44240 |    0.00000 |    0.00000 |  0.00000 | 19.04297 |  0.00000 |   69.99783 | 0.01489 | -0.00011 |    0.00633 |
# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02486 |  0.00003 |    0.01535 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02046 | -0.00033 |    0.01198 |
# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01482 | -0.00005 |    0.00895 |
# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          44240 |    0.00000 |    0.00000 |  0.00000 | 22.36328 |  0.78125 |   69.99783 | 0.01512 |  0.00037 |    0.00598 |
# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01964 | -0.00101 |    0.01122 |
# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02073 | -0.00067 |    0.01437 |
# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |          44240 |    0.00000 |    0.00000 |  0.00000 | 14.64844 |  0.00000 |   69.99783 | 0.01522 |  0.00006 |    0.00622 |
# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02328 | -0.00032 |    0.01636 |
# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02161 | -0.00079 |    0.01598 |
# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |          44240 |    0.00000 |    0.00000 |  0.00000 | 12.79297 |  0.00000 |   69.99783 | 0.01498 | -0.00022 |    0.00650 |
# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02159 | -0.00091 |    0.01488 |
# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02732 | -0.00100 |    0.01950 |
# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 24.70703 |  0.00000 |   69.99919 | 0.01165 | -0.00010 |    0.00486 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02073 | -0.00034 |    0.01470 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01429 |  0.00003 |    0.00978 |
# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01360 | -0.00050 |    0.00955 |
# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 16.29639 |  0.00000 |   69.99919 | 0.01055 |  0.00002 |    0.00442 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01820 | -0.00090 |    0.01307 |
# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01407 | -0.00041 |    0.01007 |
# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 11.88965 |  0.00000 |   69.99919 | 0.01011 | -0.00021 |    0.00433 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01698 | -0.00063 |    0.01240 |
# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01550 | -0.00058 |    0.01146 |
# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 11.92627 |  0.00000 |   69.99919 | 0.00985 | -0.00019 |    0.00429 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01625 | -0.00094 |    0.01197 |
# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01617 | -0.00080 |    0.01216 |
# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 11.99951 |  0.00000 |   69.99919 | 0.00980 | -0.00028 |    0.00428 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01623 | -0.00131 |    0.01196 |
# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01757 | -0.00075 |    0.01337 |
# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |         176952 |    0.00000 |    0.00000 |  0.00000 | 11.16943 |  0.00000 |   69.99919 | 0.01001 | -0.00032 |    0.00438 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01725 | -0.00191 |    0.01294 |
# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02114 | -0.00099 |    0.01632 |
# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         707792 |    0.00000 |    0.00000 |  0.00000 | 19.15894 |  0.00000 |   69.99986 | 0.00801 | -0.00012 |    0.00358 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01369 | -0.00055 |    0.01057 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |        2097152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00892 | -0.00015 |    0.00679 |
# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01330 | -0.00056 |    0.01038 |
# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         707792 |    0.00000 |    0.00000 |  0.00000 | 13.93127 |  0.00000 |   69.99986 | 0.00781 | -0.00028 |    0.00351 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01347 | -0.00007 |    0.01039 |
# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01628 | -0.00034 |    0.01277 |
# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         707792 |    0.00000 |    0.00000 |  0.00000 | 23.70911 |  0.00000 |   69.99986 | 0.00686 | -0.00021 |    0.00310 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01262 |  0.00002 |    0.00943 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         614400 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   70.00000 | 0.03148 |  0.00299 |    0.01480 |
# | 54 | Total sparsity:                     | -                  |      25502912 |       16147304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   36.68447 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-12-24 13:52:24,645 - Total sparsity: 36.68
#
# 2018-12-24 13:52:24,645 - --- validate (epoch=72)-----------
# 2018-12-24 13:52:24,645 - 50000 samples (256 per mini-batch)
# 2018-12-24 13:52:44,774 - Epoch: [72][   50/  195]    Loss 0.676330    Top1 82.195312    Top5 96.039062
# 2018-12-24 13:52:52,702 - Epoch: [72][  100/  195]    Loss 0.799058    Top1 79.386719    Top5 94.863281
# 2018-12-24 13:53:00,916 - Epoch: [72][  150/  195]    Loss 0.911178    Top1 77.216146    Top5 93.466146
# 2018-12-24 13:53:08,224 - ==> Top1: 76.358    Top5: 92.972    Loss: 0.952
#
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.454 on Epoch: 1
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.446 on Epoch: 0
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.416 on Epoch: 3
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.358 on Epoch: 72
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.344 on Epoch: 2
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.326 on Epoch: 69
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.320 on Epoch: 68
# 2018-12-24 13:53:08,308 - ==> Best Top1: 76.318 on Epoch: 70
# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.300 on Epoch: 58
# 2018-12-24 13:53:08,309 - ==> Best Top1: 76.284 on Epoch: 71

# Version of this schedule file.
# NOTE(review): presumably interpreted by the tool that loads the schedule — confirm.
version: 1

pruners:
  # Prunes the final fully-connected layer, ramping its sparsity
  # from 5% up to 70%.
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights: module.fc.weight

  # Block-structured pruning (L1-ranked) of the conv2 weights in every
  # layer1..layer4 block, ramping sparsity from 5% up to 70%.
  block_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.70
    group_type: Blocks
    kwargs:
      # [block_repetition, block_depth, block_height, block_width]
      block_shape: [1, 8, 1, 1]
    weights:
      - module.layer1.0.conv2.weight
      - module.layer1.1.conv2.weight
      - module.layer1.2.conv2.weight
      - module.layer2.0.conv2.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.3.conv2.weight
      - module.layer3.0.conv2.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight
      - module.layer3.3.conv2.weight
      - module.layer3.4.conv2.weight
      - module.layer3.5.conv2.weight
      - module.layer4.0.conv2.weight
      - module.layer4.1.conv2.weight
      - module.layer4.2.conv2.weight


lr_schedulers:
  # Learning-rate scheduler; the `policies` section refers to it by the
  # instance name `pruning_lr`.
  # NOTE(review): presumably this maps to PyTorch's ExponentialLR, with the
  # learning rate multiplied by `gamma` at each scheduler step — confirm
  # against the tool that consumes this schedule.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.95


policies:
  # Block pruner: scheduled from epoch 0 to epoch 30, frequency 1.
  - pruner:
      instance_name: block_pruner
    starting_epoch: 0
    ending_epoch: 30
    frequency: 1

  # FC pruner: same epoch window as the block pruner, but frequency 3.
  - pruner:
      instance_name: fc_pruner
    starting_epoch: 0
    ending_epoch: 30
    frequency: 3

  # LR scheduler: scheduled from epoch 40 to epoch 80, i.e. it starts
  # only after both pruning windows above have ended.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 40
    ending_epoch: 80
    frequency: 1
