#
# This schedule performs filter-pruning using L1-norm ranking, with AGP setting the pruning-rate decay.
#
# Best Top1: 74.472 (epoch 89)
# No. of Parameters: 12,335,296 (of 25,502,912) = 48.37% dense (51.63% sparse)
# Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.filters.yaml --validation-split=0   --num-best-scores=10 --name="resnet50_filters_v3.2"
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11126 | -0.00041 |    0.06795 |
# |  1 | module.layer1.0.conv1.weight        | (32, 64, 1, 1)     |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07716 | -0.00620 |    0.04652 |
# |  2 | module.layer1.0.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04019 |  0.00155 |    0.02598 |
# |  3 | module.layer1.0.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03806 | -0.00040 |    0.02405 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05146 | -0.00309 |    0.02865 |
# |  5 | module.layer1.1.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03163 |  0.00100 |    0.02180 |
# |  6 | module.layer1.1.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03692 |  0.00017 |    0.02594 |
# |  7 | module.layer1.1.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03186 | -0.00044 |    0.02031 |
# |  8 | module.layer1.2.conv1.weight        | (32, 256, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03012 |  0.00020 |    0.02207 |
# |  9 | module.layer1.2.conv2.weight        | (32, 32, 3, 3)     |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03639 | -0.00008 |    0.02752 |
# | 10 | module.layer1.2.conv3.weight        | (256, 32, 1, 1)    |          8192 |           8192 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02805 | -0.00242 |    0.01681 |
# | 11 | module.layer2.0.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03600 | -0.00123 |    0.02535 |
# | 12 | module.layer2.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02421 |  0.00001 |    0.01792 |
# | 13 | module.layer2.0.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02699 |  0.00001 |    0.01656 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02061 | -0.00045 |    0.01215 |
# | 15 | module.layer2.1.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01742 | -0.00019 |    0.01078 |
# | 16 | module.layer2.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02549 |  0.00069 |    0.01664 |
# | 17 | module.layer2.1.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02252 | -0.00146 |    0.01323 |
# | 18 | module.layer2.2.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02343 | -0.00057 |    0.01626 |
# | 19 | module.layer2.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02409 |  0.00013 |    0.01683 |
# | 20 | module.layer2.2.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02583 |  0.00014 |    0.01794 |
# | 21 | module.layer2.3.conv1.weight        | (64, 512, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02346 | -0.00083 |    0.01743 |
# | 22 | module.layer2.3.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02384 | -0.00043 |    0.01808 |
# | 23 | module.layer2.3.conv3.weight        | (512, 64, 1, 1)    |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02310 | -0.00121 |    0.01595 |
# | 24 | module.layer3.0.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02991 | -0.00090 |    0.02168 |
# | 25 | module.layer3.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01797 | -0.00014 |    0.01337 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02274 | -0.00045 |    0.01635 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01442 | -0.00001 |    0.00989 |
# | 28 | module.layer3.1.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01523 | -0.00036 |    0.01075 |
# | 29 | module.layer3.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01689 | -0.00003 |    0.01217 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01965 | -0.00062 |    0.01399 |
# | 31 | module.layer3.2.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01552 | -0.00033 |    0.01106 |
# | 32 | module.layer3.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01654 | -0.00056 |    0.01220 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01840 | -0.00052 |    0.01337 |
# | 34 | module.layer3.3.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01680 | -0.00057 |    0.01253 |
# | 35 | module.layer3.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01554 | -0.00054 |    0.01182 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01744 | -0.00095 |    0.01284 |
# | 37 | module.layer3.4.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01740 | -0.00080 |    0.01313 |
# | 38 | module.layer3.4.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01537 | -0.00063 |    0.01168 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01709 | -0.00117 |    0.01251 |
# | 40 | module.layer3.5.conv1.weight        | (128, 1024, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01870 | -0.00072 |    0.01435 |
# | 41 | module.layer3.5.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01530 | -0.00071 |    0.01172 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 128, 1, 1)  |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01855 | -0.00211 |    0.01394 |
# | 43 | module.layer4.0.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02221 | -0.00087 |    0.01715 |
# | 44 | module.layer4.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01235 | -0.00011 |    0.00962 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01455 | -0.00057 |    0.01133 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |        2097152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00905 | -0.00018 |    0.00689 |
# | 47 | module.layer4.1.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01432 | -0.00032 |    0.01120 |
# | 48 | module.layer4.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01232 | -0.00060 |    0.00966 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01435 |  0.00002 |    0.01111 |
# | 50 | module.layer4.2.conv1.weight        | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01780 | -0.00008 |    0.01399 |
# | 51 | module.layer4.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01082 | -0.00033 |    0.00850 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 256, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01316 |  0.00019 |    0.00993 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |        2048000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03326 |  0.00000 |    0.02289 |
# | 54 | Total sparsity:                     | -                  |      12335296 |       12335296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=99)-----------
# 50000 samples (256 per mini-batch)
# Epoch: [99][   50/  195]    Loss 0.736064    Top1 80.679688    Top5 95.328125
# Epoch: [99][  100/  195]    Loss 0.859465    Top1 77.906250    Top5 93.902344
# Epoch: [99][  150/  195]    Loss 0.982479    Top1 75.354167    Top5 92.414062
# ==> Top1: 74.440    Top5: 91.936    Loss: 1.025
#
# ==> Best Top1: 75.880 on Epoch: 0
# ==> Best Top1: 75.380 on Epoch: 1
# ==> Best Top1: 74.714 on Epoch: 2
# ==> Best Top1: 74.472 on Epoch: 89   <===
# ==> Best Top1: 74.442 on Epoch: 94
# ==> Best Top1: 74.440 on Epoch: 99
# ==> Best Top1: 74.436 on Epoch: 92
# ==> Best Top1: 74.434 on Epoch: 78
# ==> Best Top1: 74.432 on Epoch: 81
# ==> Best Top1: 74.422 on Epoch: 90
# Saving checkpoint to: logs/resnet50_filters_v3.2___2018.12.09-135817/resnet50_filters_v3.2_checkpoint.pth.tar
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 0.736064    Top1 80.679688    Top5 95.328125
# Test: [  100/  195]    Loss 0.859465    Top1 77.906250    Top5 93.902344
# Test: [  150/  195]    Loss 0.982479    Top1 75.354167    Top5 92.414062
# ==> Top1: 74.440    Top5: 91.936    Loss: 1.025
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/resnet50_filters_v3.2___2018.12.09-135817/resnet50_filters_v3.2___2018.12.09-135817.log
#
# real    2674m27.621s
# user    29485m43.111s
# sys     2504m41.802s

# Schedule file format version.
# NOTE(review): presumably checked by the schedule loader — confirm.
version: 1

pruners:
  # AGP pruner for the fully-connected layer weights (sparsity 0.05 -> 0.87).
  # NOTE(review): the policy entry that references this pruner is commented
  # out in this file, so it appears to be defined but unused.
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.87
    weights: module.fc.weight

  # Disabled alternative: a structured variant of fc_pruner.
  # fc_pruner:
  #   class: StructuredAutomatedGradualPruner
  #   initial_sparsity: 0.05
  #   final_sparsity: 0.50
  #   reg_regims:
  #     module.fc.weight: Rows

  # Filter pruner (group_type: Filters) using the L1-ranked AGP class;
  # sparsity goes from 0.05 to 0.50 over the epochs set by its policy.
  filter_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Filters
    weights:
      # conv1 weights of every layerN.M block
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer3.0.conv1.weight
      - module.layer3.1.conv1.weight
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.4.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer4.0.conv1.weight
      - module.layer4.1.conv1.weight
      - module.layer4.2.conv1.weight
      # conv2 weights of every layerN.M block
      - module.layer1.0.conv2.weight
      - module.layer1.1.conv2.weight
      - module.layer1.2.conv2.weight
      - module.layer2.0.conv2.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.3.conv2.weight
      - module.layer3.0.conv2.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight
      - module.layer3.3.conv2.weight
      - module.layer3.4.conv2.weight
      - module.layer3.5.conv2.weight
      - module.layer4.0.conv2.weight
      - module.layer4.1.conv2.weight
      - module.layer4.2.conv2.weight

  # AGP pruner over layer4 weights (sparsity 0.05 -> 0.70).
  # NOTE(review): the policy entry that references this pruner is commented
  # out in this file, so it appears to be defined but unused.
  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights:
      - module.layer4.0.conv2.weight
      - module.layer4.0.conv3.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight
      - module.layer4.1.conv2.weight
      - module.layer4.1.conv3.weight
      - module.layer4.2.conv1.weight
      - module.layer4.2.conv2.weight
      - module.layer4.2.conv3.weight

extensions:
  # Thinning extension; referenced by instance name from the `extension`
  # policy entry in the policies list.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet50
    dataset: imagenet

lr_schedulers:
  # Scheduler instance referenced by name from the `lr_scheduler` policy entry.
  # NOTE(review): presumably resolves to torch.optim.lr_scheduler.ExponentialLR,
  # which multiplies the learning rate by `gamma` on each step — confirm.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.95

policies:
  # Disabled: policy for fine_pruner.
  # - pruner:
  #     instance_name: fine_pruner
  #   starting_epoch: 0
  #   ending_epoch: 45
  #   frequency: 3

  # Active: filter pruning, starting_epoch 0 through ending_epoch 30.
  - pruner:
      instance_name: filter_pruner
      # args:
      #   mini_batch_pruning_frequency: 1
    starting_epoch: 0
    ending_epoch: 30
    frequency: 1

  # Disabled: policy for fc_pruner.
  # - pruner:
  #     instance_name: fc_pruner
  #   starting_epoch: 0
  #   ending_epoch: 30
  #   frequency: 3

  # After completing the pruning, we perform network thinning and continue
  # fine-tuning.
  - extension:
      instance_name: net_thinner
    epochs: [31]

  # Learning-rate schedule applied from starting_epoch 40 to ending_epoch 80.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 40
    ending_epoch: 80
    frequency: 1
