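# This schedule demonstrates high-rate element-wise pruning (84.6% sparsity) of ResNet-50.
# Top1 is 75.66 vs. the published Top1 of 76.15 (https://pytorch.org/docs/stable/torchvision/models.html), i.e. a drop of ~0.5%.
# (75.66 is the best validation Top1, reached at epoch 94; the final test Top1 after epoch 119 is 75.544.)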
#
# The pruning level is uniform (85%) across all convolution layers, except for the first convolution (module.conv1),
# which is left dense. The last Linear layer (module.fc) is pruned to 80% sparsity.
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=12 --epochs=120 --lr=0.005 --compress=../agp-pruning/resnet50.schedule_agp.yaml --vs=0
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10833 | -0.00039 |    0.06693 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |            615 |    0.00000 |    0.00000 |  1.56250 | 84.98535 |  7.81250 |   84.98535 | 0.04060 | -0.00275 |    0.01329 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |           5530 |    0.00000 |    0.00000 |  7.81250 | 46.70410 |  6.25000 |   84.99891 | 0.02069 |  0.00051 |    0.00656 |
# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 |  6.25000 | 84.99756 | 14.06250 |   84.99756 | 0.02402 |  0.00024 |    0.00802 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 |  1.56250 | 84.99756 | 14.45312 |   84.99756 | 0.03334 | -0.00160 |    0.01112 |
# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 | 12.89062 | 84.99756 |  6.25000 |   84.99756 | 0.02021 |  0.00071 |    0.00695 |
# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |           5530 |    0.00000 |    0.00000 |  6.25000 | 40.82031 |  0.00000 |   84.99891 | 0.01920 |  0.00011 |    0.00638 |
# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 |  0.00000 | 84.99756 | 10.54688 |   84.99756 | 0.02166 |  0.00015 |    0.00746 |
# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 |  9.76562 | 84.99756 |  0.00000 |   84.99756 | 0.01916 | -0.00012 |    0.00670 |
# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |           5530 |    0.00000 |    0.00000 |  0.00000 | 36.81641 |  0.00000 |   84.99891 | 0.02004 | -0.00037 |    0.00708 |
# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |           2458 |    0.00000 |    0.00000 |  0.00000 | 84.99756 | 12.10938 |   84.99756 | 0.02094 | -0.00089 |    0.00710 |
# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |           4916 |    0.00000 |    0.00000 |  7.42188 | 84.99756 |  0.00000 |   84.99756 | 0.02132 | -0.00057 |    0.00739 |
# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 44.40308 |  0.00000 |   84.99959 | 0.01355 | -0.00008 |    0.00468 |
# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 | 26.36719 |   84.99908 | 0.01814 |  0.00024 |    0.00600 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |          19661 |    0.00000 |    0.00000 |  0.00000 | 84.99985 | 12.50000 |   84.99985 | 0.01384 | -0.00018 |    0.00429 |
# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 | 13.47656 | 84.99908 |  0.00000 |   84.99908 | 0.01146 |  0.00018 |    0.00366 |
# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 47.57080 |  0.00000 |   84.99959 | 0.01386 |  0.00020 |    0.00461 |
# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 | 16.99219 |   84.99908 | 0.01543 | -0.00076 |    0.00495 |
# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  1.56250 | 84.99908 |  0.00000 |   84.99908 | 0.01505 | -0.00032 |    0.00503 |
# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 39.97192 |  0.00000 |   84.99959 | 0.01382 | -0.00008 |    0.00462 |
# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 |  2.73438 |   84.99908 | 0.01663 | -0.00013 |    0.00562 |
# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.78125 | 84.99908 |  0.00000 |   84.99908 | 0.01522 | -0.00029 |    0.00521 |
# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 35.00366 |  0.00000 |   84.99959 | 0.01394 | -0.00023 |    0.00483 |
# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 | 11.52344 |   84.99908 | 0.01592 | -0.00024 |    0.00537 |
# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |          19661 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.00000 |   84.99985 | 0.01860 | -0.00033 |    0.00644 |
# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 50.09918 |  0.00000 |   84.99993 | 0.01041 | -0.00010 |    0.00351 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  4.49219 |   84.99985 | 0.01460 |  0.00004 |    0.00500 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |          78644 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  4.00391 |   84.99985 | 0.00992 |  0.00001 |    0.00320 |
# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  4.88281 | 84.99985 |  0.00000 |   84.99985 | 0.01106 | -0.00008 |    0.00367 |
# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 41.85333 |  0.00000 |   84.99993 | 0.01006 | -0.00006 |    0.00338 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.97656 |   84.99985 | 0.01332 | -0.00063 |    0.00451 |
# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.78125 | 84.99985 |  0.00000 |   84.99985 | 0.01074 | -0.00015 |    0.00351 |
# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 34.92432 |  0.00000 |   84.99993 | 0.00978 | -0.00026 |    0.00331 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.48828 |   84.99985 | 0.01219 | -0.00025 |    0.00413 |
# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.19531 | 84.99985 |  0.00000 |   84.99985 | 0.01165 | -0.00015 |    0.00389 |
# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 34.15985 |  0.00000 |   84.99993 | 0.00979 | -0.00025 |    0.00333 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  1.17188 |   84.99985 | 0.01197 | -0.00044 |    0.00409 |
# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.09766 | 84.99985 |  0.00000 |   84.99985 | 0.01195 | -0.00023 |    0.00405 |
# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 35.31799 |  0.00000 |   84.99993 | 0.00976 | -0.00031 |    0.00334 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  1.26953 |   84.99985 | 0.01214 | -0.00063 |    0.00416 |
# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.00000 |   84.99985 | 0.01269 | -0.00017 |    0.00437 |
# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 37.64801 |  0.00000 |   84.99993 | 0.00997 | -0.00035 |    0.00344 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  1.36719 |   84.99985 | 0.01306 | -0.00101 |    0.00456 |
# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |          78644 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.00000 |   84.99985 | 0.01451 | -0.00024 |    0.00516 |
# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         353895 |    0.00000 |    0.00000 |  0.00000 | 49.59564 |  0.00000 |   84.99997 | 0.00783 | -0.00011 |    0.00274 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.00000 |   84.99994 | 0.01082 | -0.00011 |    0.00380 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |         314573 |    0.00000 |    0.00000 |  0.00000 | 84.99999 |  0.00000 |   84.99999 | 0.00731 | -0.00010 |    0.00242 |
# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.00000 |   84.99994 | 0.01125 | -0.00016 |    0.00392 |
# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         353895 |    0.00000 |    0.00000 |  0.00000 | 44.37675 |  0.00000 |   84.99997 | 0.00827 | -0.00030 |    0.00292 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.00000 |   84.99994 | 0.01120 |  0.00017 |    0.00395 |
# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.00000 |   84.99994 | 0.01296 | -0.00004 |    0.00460 |
# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         353895 |    0.00000 |    0.00000 |  0.00000 | 59.08966 |  0.00000 |   84.99997 | 0.00678 | -0.00017 |    0.00248 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.04883 |   84.99994 | 0.00908 |  0.00024 |    0.00320 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |         409600 |    0.00000 |    0.09766 |  0.00000 |  0.00000 |  0.00000 |   80.00000 | 0.03174 |  0.00410 |    0.01240 |
# | 54 | Total sparsity:                     | -                  |      25502912 |        3935859 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   84.56702 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2019-03-23 23:01:26,400 - Total sparsity: 84.57
# 
# 2019-03-23 23:01:26,400 - --- validate (epoch=119)-----------
# 2019-03-23 23:01:26,400 - 50000 samples (256 per mini-batch)
# 2019-03-23 23:01:56,810 - Epoch: [119][   50/  195]    Loss 0.971985    Top1 75.554688    Top5 92.703125
# 2019-03-23 23:02:18,645 - Epoch: [119][  100/  195]    Loss 0.977343    Top1 75.527344    Top5 92.597656
# 2019-03-23 23:02:40,429 - Epoch: [119][  150/  195]    Loss 0.975216    Top1 75.455729    Top5 92.664062
# 2019-03-23 23:02:56,350 - ==> Top1: 75.544    Top5: 92.760    Loss: 0.969
# 
# 2019-03-23 23:02:56,366 - ==> Best [Top1: 75.662   Top5: 92.726   Sparsity:84.57   Params: 3935859 on epoch: 94]
# 2019-03-23 23:02:56,366 - Saving checkpoint to: logs/2019.03.21-003631/checkpoint.pth.tar
# 2019-03-23 23:02:56,799 - --- test ---------------------
# 2019-03-23 23:02:56,800 - 50000 samples (256 per mini-batch)
# 2019-03-23 23:03:19,606 - Test: [   50/  195]    Loss 0.988068    Top1 75.031250    Top5 92.539062
# 2019-03-23 23:03:35,571 - Test: [  100/  195]    Loss 0.978709    Top1 75.328125    Top5 92.664062
# 2019-03-23 23:03:51,860 - Test: [  150/  195]    Loss 0.968993    Top1 75.442708    Top5 92.755208
# 2019-03-23 23:04:08,382 - ==> Top1: 75.544    Top5: 92.760    Loss: 0.968

---
# Distiller compression-schedule format version.
version: 1

pruners:
  # Distiller AutomatedGradualPruner (AGP): sparsity is raised gradually from
  # initial_sparsity to final_sparsity over the epoch window that the matching
  # entry in `policies` assigns to this pruner.

  # Final fully-connected layer: target 80% element-wise sparsity.
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.80
    weights: module.fc.weight

  # Every convolution except the first one: target 85% element-wise sparsity.
  conv_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.85
    weights:
      # module.conv1.weight is intentionally not listed: the first
      # convolution is left dense.
      - module.layer1.0.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer1.0.conv3.weight
      - module.layer1.0.downsample.0.weight
      - module.layer1.1.conv1.weight
      - module.layer1.1.conv2.weight
      - module.layer1.1.conv3.weight
      - module.layer1.2.conv1.weight
      - module.layer1.2.conv2.weight
      - module.layer1.2.conv3.weight
      - module.layer2.0.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.0.conv3.weight
      - module.layer2.0.downsample.0.weight
      - module.layer2.1.conv1.weight
      - module.layer2.1.conv2.weight
      - module.layer2.1.conv3.weight
      - module.layer2.2.conv1.weight
      - module.layer2.2.conv2.weight
      - module.layer2.2.conv3.weight
      - module.layer2.3.conv1.weight
      - module.layer2.3.conv2.weight
      - module.layer2.3.conv3.weight
      - module.layer3.0.conv1.weight
      - module.layer3.0.conv2.weight
      - module.layer3.0.conv3.weight
      - module.layer3.0.downsample.0.weight
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer3.1.conv3.weight
      - module.layer3.2.conv1.weight
      - module.layer3.2.conv2.weight
      - module.layer3.2.conv3.weight
      - module.layer3.3.conv1.weight
      - module.layer3.3.conv2.weight
      - module.layer3.3.conv3.weight
      - module.layer3.4.conv1.weight
      - module.layer3.4.conv2.weight
      - module.layer3.4.conv3.weight
      - module.layer3.5.conv1.weight
      - module.layer3.5.conv2.weight
      - module.layer3.5.conv3.weight
      - module.layer4.0.conv1.weight
      - module.layer4.0.conv2.weight
      - module.layer4.0.conv3.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight
      - module.layer4.1.conv2.weight
      - module.layer4.1.conv3.weight
      - module.layer4.2.conv1.weight
      - module.layer4.2.conv2.weight
      - module.layer4.2.conv3.weight


lr_schedulers:
  # PyTorch ExponentialLR: each scheduler step multiplies the learning rate
  # by gamma. When it steps is set by the `lr_scheduler` policy below.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.95


policies:
  # Apply the convolution pruner every epoch (frequency: 1) within the
  # window starting_epoch=0 .. ending_epoch=35.
  - pruner:
      instance_name: conv_pruner
    starting_epoch: 0
    ending_epoch: 35
    frequency: 1

  # The fc pruner uses the same window, but starts one epoch later
  # than the convolution pruner.
  - pruner:
      instance_name: fc_pruner
    starting_epoch: 1
    ending_epoch: 35
    frequency: 1

  # Exponential LR decay only begins after the pruning window has closed.
  # It steps every epoch between epochs 40 and 100. Presumably the LR is
  # left unchanged outside this window (the run uses --epochs=120) — confirm
  # against Distiller's policy semantics.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 40
    ending_epoch: 100
    frequency: 1
