# This schedule demonstrates high-rate element-wise pruning (80% sparsity) of ResNet-50.
# Best Top1 is 75.99 (epoch 94) vs. the published Top1 of 76.15 (https://pytorch.org/docs/stable/torchvision/models.html).
# Best Top5 (92.87) is on par with the baseline.
#
# The pruning level is uniform across all layers (80%), except for the first convolution (module.conv1), which is not pruned.
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=12 --epochs=100 --lr=0.005 --compress=../agp-pruning/resnet50.schedule_agp.yaml --vs=0
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10902 | -0.00039 |    0.06756 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |            820 |    0.00000 |    0.00000 |  1.56250 | 79.98047 |  7.81250 |   79.98047 | 0.04406 | -0.00270 |    0.01620 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |           7373 |    0.00000 |    0.00000 |  7.81250 | 36.27930 |  6.25000 |   79.99946 | 0.02160 |  0.00050 |    0.00779 |
# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 |  6.25000 | 79.99878 | 13.28125 |   79.99878 | 0.02543 |  0.00032 |    0.00974 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 |  1.56250 | 79.99878 | 13.67188 |   79.99878 | 0.03585 | -0.00183 |    0.01348 |
# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 | 11.71875 | 79.99878 |  6.25000 |   79.99878 | 0.02139 |  0.00075 |    0.00844 |
# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |           7373 |    0.00000 |    0.00000 |  6.25000 | 30.76172 |  0.00000 |   79.99946 | 0.02009 |  0.00011 |    0.00763 |
# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 |  0.00000 | 79.99878 |  7.03125 |   79.99878 | 0.02291 |  0.00013 |    0.00891 |
# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 |  8.20312 | 79.99878 |  0.00000 |   79.99878 | 0.02034 | -0.00007 |    0.00816 |
# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |           7373 |    0.00000 |    0.00000 |  0.00000 | 26.29395 |  0.00000 |   79.99946 | 0.02126 | -0.00038 |    0.00860 |
# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |           3277 |    0.00000 |    0.00000 |  0.00000 | 79.99878 |  7.03125 |   79.99878 | 0.02220 | -0.00112 |    0.00856 |
# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |           6554 |    0.00000 |    0.00000 |  3.51562 | 79.99878 |  0.00000 |   79.99878 | 0.02269 | -0.00074 |    0.00903 |
# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          29492 |    0.00000 |    0.00000 |  0.00000 | 32.34253 |  0.00000 |   79.99946 | 0.01436 | -0.00008 |    0.00567 |
# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  0.00000 | 79.99878 | 18.75000 |   79.99878 | 0.01925 |  0.00021 |    0.00717 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |          26215 |    0.00000 |    0.00000 |  0.00000 | 79.99954 | 12.30469 |   79.99954 | 0.01469 | -0.00023 |    0.00518 |
# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 | 12.89062 | 79.99878 |  0.00000 |   79.99878 | 0.01206 |  0.00011 |    0.00439 |
# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          29492 |    0.00000 |    0.00000 |  0.00000 | 36.49902 |  0.00000 |   79.99946 | 0.01451 |  0.00018 |    0.00548 |
# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  0.00000 | 79.99878 |  3.71094 |   79.99878 | 0.01631 | -0.00087 |    0.00588 |
# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  1.56250 | 79.99878 |  0.00000 |   79.99878 | 0.01590 | -0.00040 |    0.00605 |
# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |          29492 |    0.00000 |    0.00000 |  0.00000 | 28.51562 |  0.00000 |   79.99946 | 0.01464 | -0.00008 |    0.00558 |
# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  0.00000 | 79.99878 |  2.14844 |   79.99878 | 0.01771 | -0.00020 |    0.00682 |
# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  0.19531 | 79.99878 |  0.00000 |   79.99878 | 0.01613 | -0.00042 |    0.00634 |
# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |          29492 |    0.00000 |    0.00000 |  0.00000 | 24.03564 |  0.00000 |   79.99946 | 0.01476 | -0.00026 |    0.00586 |
# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |          13108 |    0.00000 |    0.00000 |  0.00000 | 79.99878 |  4.10156 |   79.99878 | 0.01678 | -0.00034 |    0.00641 |
# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |          26215 |    0.00000 |    0.00000 |  0.00000 | 79.99954 |  0.00000 |   79.99954 | 0.01981 | -0.00048 |    0.00781 |
# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 38.29956 |  0.00000 |   79.99997 | 0.01108 | -0.00012 |    0.00427 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  4.39453 |   79.99992 | 0.01559 | -0.00001 |    0.00608 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         104858 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  4.00391 |   79.99992 | 0.01054 | -0.00000 |    0.00388 |
# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  4.58984 | 79.99992 |  0.00000 |   79.99992 | 0.01161 | -0.00015 |    0.00440 |
# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 30.37567 |  0.00000 |   79.99997 | 0.01065 | -0.00009 |    0.00409 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.68359 |   79.99992 | 0.01423 | -0.00072 |    0.00548 |
# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.68359 | 79.99992 |  0.00000 |   79.99992 | 0.01134 | -0.00020 |    0.00424 |
# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 23.76862 |  0.00000 |   79.99997 | 0.01032 | -0.00033 |    0.00400 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.19531 |   79.99992 | 0.01298 | -0.00031 |    0.00501 |
# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.19531 | 79.99992 |  0.00000 |   79.99992 | 0.01234 | -0.00023 |    0.00471 |
# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 23.16437 |  0.00000 |   79.99997 | 0.01036 | -0.00030 |    0.00404 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.39062 |   79.99992 | 0.01273 | -0.00055 |    0.00495 |
# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.09766 | 79.99992 |  0.00000 |   79.99992 | 0.01271 | -0.00035 |    0.00492 |
# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 24.42474 |  0.00000 |   79.99997 | 0.01033 | -0.00038 |    0.00405 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.29297 |   79.99992 | 0.01291 | -0.00077 |    0.00505 |
# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01351 | -0.00029 |    0.00532 |
# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |         117965 |    0.00000 |    0.00000 |  0.00000 | 26.96075 |  0.00000 |   79.99997 | 0.01055 | -0.00040 |    0.00417 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          52429 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.68359 |   79.99992 | 0.01390 | -0.00120 |    0.00555 |
# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |         104858 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01559 | -0.00040 |    0.00635 |
# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         471860 |    0.00000 |    0.00000 |  0.00000 | 38.93700 |  0.00000 |   79.99997 | 0.00838 | -0.00015 |    0.00335 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         209716 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01160 | -0.00020 |    0.00466 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |         419431 |    0.00000 |    0.00000 |  0.00000 | 79.99997 |  0.00000 |   79.99997 | 0.00780 | -0.00013 |    0.00296 |
# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         209716 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01202 | -0.00025 |    0.00479 |
# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         471860 |    0.00000 |    0.00000 |  0.00000 | 33.88023 |  0.00000 |   79.99997 | 0.00884 | -0.00036 |    0.00357 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         209716 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01205 |  0.00008 |    0.00487 |
# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         209716 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.01396 | -0.00011 |    0.00568 |
# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         471860 |    0.00000 |    0.00000 |  0.00000 | 50.91476 |  0.00000 |   79.99997 | 0.00723 | -0.00022 |    0.00303 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         209716 |    0.00000 |    0.00000 |  0.00000 | 79.99992 |  0.00000 |   79.99992 | 0.00957 |  0.00020 |    0.00386 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         409600 |    0.00000 |    0.04883 |  0.00000 |  0.00000 |  0.00000 |   80.00000 | 0.03149 |  0.00414 |    0.01235 |
# | 54 | Total sparsity:                     | -                  |      25502912 |        5108133 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   79.97039 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2019-03-20 18:14:17,059 - Total sparsity: 79.97
#
# 2019-03-20 18:14:17,059 - --- validate (epoch=98)-----------
# 2019-03-20 18:14:17,059 - 50000 samples (256 per mini-batch)
# 2019-03-20 18:14:47,289 - Epoch: [98][   50/  195]    Loss 0.958758    Top1 75.703125    Top5 92.843750
# 2019-03-20 18:15:09,204 - Epoch: [98][  100/  195]    Loss 0.961983    Top1 75.789062    Top5 92.804688
# 2019-03-20 18:15:35,028 - Epoch: [98][  150/  195]    Loss 0.956074    Top1 75.776042    Top5 92.848958
# 2019-03-20 18:15:50,982 - ==> Top1: 75.838    Top5: 92.868    Loss: 0.959
#
# 2019-03-20 18:15:50,998 - ==> Best [Top1: 75.990   Top5: 92.872   Sparsity:79.97   Params: 5108133 on epoch: 94]
# 2019-03-20 18:15:50,998 - Saving checkpoint to: logs/2019.03.18-090917/checkpoint.pth.tar
#
# real    3463m11.943s
# user    31959m34.272s
# sys     2745m57.392s

# Schedule file format version (presumably checked by the schedule parser — confirm).
version: 1

# Pruner definitions. Each pruner is an AutomatedGradualPruner whose sparsity
# presumably ramps from initial_sparsity to final_sparsity over the epoch
# window set by its policy entry (see `policies`).
pruners:
  # Prunes the classifier weights (module.fc.weight, shape (1000, 2048)).
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.80
    # A one-element list, matching the form of conv_pruner's `weights`, so the
    # value parses as a sequence of parameter names rather than a bare string.
    weights:
      - module.fc.weight

  # Prunes every convolution weight tensor (including the downsample
  # convolutions) except the first convolution.
  conv_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.80
    weights:
      # module.conv1.weight is intentionally excluded and stays dense.
      - module.layer1.0.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer1.0.conv3.weight
      - module.layer1.0.downsample.0.weight
      - module.layer1.1.conv1.weight
      - module.layer1.1.conv2.weight
      - module.layer1.1.conv3.weight
      - module.layer1.2.conv1.weight
      - module.layer1.2.conv2.weight
      - module.layer1.2.conv3.weight
      - module.layer2.0.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.0.conv3.weight
      - module.layer2.0.downsample.0.weight
      - module.layer2.1.conv1.weight
      - module.layer2.1.conv2.weight
      - module.layer2.1.conv3.weight
      - module.layer2.2.conv1.weight
      - module.layer2.2.conv2.weight
      - module.layer2.2.conv3.weight
      - module.layer2.3.conv1.weight
      - module.layer2.3.conv2.weight
      - module.layer2.3.conv3.weight
      - module.layer3.0.conv1.weight
      - module.layer3.0.conv2.weight
      - module.layer3.0.conv3.weight
      - module.layer3.0.downsample.0.weight
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer3.1.conv3.weight
      - module.layer3.2.conv1.weight
      - module.layer3.2.conv2.weight
      - module.layer3.2.conv3.weight
      - module.layer3.3.conv1.weight
      - module.layer3.3.conv2.weight
      - module.layer3.3.conv3.weight
      - module.layer3.4.conv1.weight
      - module.layer3.4.conv2.weight
      - module.layer3.4.conv3.weight
      - module.layer3.5.conv1.weight
      - module.layer3.5.conv2.weight
      - module.layer3.5.conv3.weight
      - module.layer4.0.conv1.weight
      - module.layer4.0.conv2.weight
      - module.layer4.0.conv3.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight
      - module.layer4.1.conv2.weight
      - module.layer4.1.conv3.weight
      - module.layer4.2.conv1.weight
      - module.layer4.2.conv2.weight
      - module.layer4.2.conv3.weight

# Learning-rate schedulers referenced by `policies`.
lr_schedulers:
  # Exponential decay with gamma 0.95 — presumably the LR is multiplied by
  # gamma each time the scheduler is stepped (confirm against the scheduler class).
  pruning_lr:
    class: ExponentialLR
    gamma: 0.95


# When each pruner / scheduler defined above is active.
# frequency: 1 presumably means "apply every epoch" within the window.
policies:
  # Convolution pruning: starting_epoch 0, ending_epoch 35.
  - pruner:
      instance_name: conv_pruner
    starting_epoch: 0
    ending_epoch: 35
    frequency: 1

  # FC pruning: same ending_epoch as conv_pruner, but starts one epoch later.
  - pruner:
      instance_name: fc_pruner
    starting_epoch: 1
    ending_epoch: 35
    frequency: 1

  # LR decay begins at epoch 40, after both pruning windows have ended (35),
  # and runs to epoch 100.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 40
    ending_epoch: 100
    frequency: 1
