# This schedule demonstrates low-rate pruning (26% sparsity) acting as a regularizer that reduces the generalization
# error of a pretrained ResNet50 fine-tuned on the ImageNet dataset.
# Test-set Top1 is 76.538 (23.462% error rate) vs. the published Top1 of 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
#
# I ran this for 80 epochs, but a much shorter run (perhaps 50 epochs) would probably produce the same results.
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=80 --lr=0.001 --compress=resnet50.schedule_agp.yaml
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11423 | -0.00048 |    0.07023 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 1, 1)     |          4096 |            984 |    0.00000 |    0.00000 |  3.12500 | 75.97656 |  7.81250 |   75.97656 | 0.06234 | -0.00488 |    0.02488 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)     |         36864 |           8848 |    0.00000 |    0.00000 |  7.81250 | 33.88672 |  6.25000 |   75.99826 | 0.02540 |  0.00064 |    0.01024 |
# |  3 | module.layer1.0.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03259 |  0.00035 |    0.01952 |
# |  4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05311 | -0.00314 |    0.02976 |
# |  5 | module.layer1.1.conv1.weight        | (64, 256, 1, 1)    |         16384 |           5407 |    0.00000 |    0.00000 | 11.71875 | 66.99829 |  6.25000 |   66.99829 | 0.02694 |  0.00116 |    0.01374 |
# |  6 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)     |         36864 |          12166 |    0.00000 |    0.00000 |  6.25000 | 16.67480 |  0.00000 |   66.99761 | 0.02510 |  0.00015 |    0.01256 |
# |  7 | module.layer1.1.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03004 | -0.00007 |    0.01880 |
# |  8 | module.layer1.2.conv1.weight        | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02775 |  0.00012 |    0.02005 |
# |  9 | module.layer1.2.conv2.weight        | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02927 | -0.00069 |    0.02190 |
# | 10 | module.layer1.2.conv3.weight        | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02861 | -0.00222 |    0.01712 |
# | 11 | module.layer2.0.conv1.weight        | (128, 256, 1, 1)   |         32768 |          10814 |    0.00000 |    0.00000 |  0.00000 | 66.99829 |  0.00000 |   66.99829 | 0.03077 | -0.00121 |    0.01567 |
# | 12 | module.layer2.0.conv2.weight        | (128, 128, 3, 3)   |        147456 |          58983 |    0.00000 |    0.00000 |  0.00000 |  7.04956 |  0.00000 |   59.99959 | 0.01942 | -0.00032 |    0.01106 |
# | 13 | module.layer2.0.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02581 | -0.00001 |    0.01597 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1)   |        131072 |          43254 |    0.00000 |    0.00000 |  0.00000 | 66.99982 | 12.30469 |   66.99982 | 0.02055 | -0.00029 |    0.00925 |
# | 15 | module.layer2.1.conv1.weight        | (128, 512, 1, 1)   |         65536 |          15729 |    0.00000 |    0.00000 | 13.28125 | 75.99945 |  0.00000 |   75.99945 | 0.01449 |  0.00011 |    0.00605 |
# | 16 | module.layer2.1.conv2.weight        | (128, 128, 3, 3)   |        147456 |          35390 |    0.00000 |    0.00000 |  0.00000 | 31.81763 |  0.00000 |   75.99962 | 0.01666 |  0.00021 |    0.00694 |
# | 17 | module.layer2.1.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02037 | -0.00107 |    0.01159 |
# | 18 | module.layer2.2.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02152 | -0.00070 |    0.01494 |
# | 19 | module.layer2.2.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01991 | -0.00026 |    0.01415 |
# | 20 | module.layer2.2.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02417 | -0.00039 |    0.01701 |
# | 21 | module.layer2.3.conv1.weight        | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02241 | -0.00083 |    0.01660 |
# | 22 | module.layer2.3.conv2.weight        | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02064 | -0.00059 |    0.01555 |
# | 23 | module.layer2.3.conv3.weight        | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02242 | -0.00098 |    0.01548 |
# | 24 | module.layer3.0.conv1.weight        | (256, 512, 1, 1)   |        131072 |          31458 |    0.00000 |    0.00000 |  0.00000 | 75.99945 |  0.00000 |   75.99945 | 0.02543 | -0.00054 |    0.01128 |
# | 25 | module.layer3.0.conv2.weight        | (256, 256, 3, 3)   |        589824 |         194642 |    0.00000 |    0.00000 |  0.00000 | 16.35742 |  0.00000 |   66.99999 | 0.01480 | -0.00026 |    0.00767 |
# | 26 | module.layer3.0.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02153 | -0.00034 |    0.01529 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01485 |  0.00006 |    0.01016 |
# | 28 | module.layer3.1.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         104858 |    0.00000 |    0.00000 |  4.58984 | 59.99985 |  0.00000 |   59.99985 | 0.01352 | -0.00038 |    0.00743 |
# | 29 | module.layer3.1.conv2.weight        | (256, 256, 3, 3)   |        589824 |         235930 |    0.00000 |    0.00000 |  0.00000 |  6.40717 |  0.00000 |   59.99993 | 0.01325 | -0.00017 |    0.00739 |
# | 30 | module.layer3.1.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01890 | -0.00097 |    0.01357 |
# | 31 | module.layer3.2.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01459 | -0.00046 |    0.01045 |
# | 32 | module.layer3.2.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01385 | -0.00061 |    0.01041 |
# | 33 | module.layer3.2.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01762 | -0.00069 |    0.01289 |
# | 34 | module.layer3.3.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01607 | -0.00066 |    0.01190 |
# | 35 | module.layer3.3.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01379 | -0.00066 |    0.01055 |
# | 36 | module.layer3.3.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01686 | -0.00102 |    0.01244 |
# | 37 | module.layer3.4.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01678 | -0.00087 |    0.01263 |
# | 38 | module.layer3.4.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01375 | -0.00081 |    0.01055 |
# | 39 | module.layer3.4.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01685 | -0.00141 |    0.01242 |
# | 40 | module.layer3.5.conv1.weight        | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01826 | -0.00079 |    0.01390 |
# | 41 | module.layer3.5.conv2.weight        | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01409 | -0.00080 |    0.01082 |
# | 42 | module.layer3.5.conv3.weight        | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01791 | -0.00203 |    0.01343 |
# | 43 | module.layer4.0.conv1.weight        | (512, 1024, 1, 1)  |        524288 |         209716 |    0.00000 |    0.00000 |  0.00000 | 59.99985 |  0.00000 |   59.99985 | 0.02063 | -0.00079 |    0.01202 |
# | 44 | module.layer4.0.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         943719 |    0.00000 |    0.00000 |  0.00000 | 10.43282 |  0.00000 |   59.99997 | 0.01083 | -0.00032 |    0.00638 |
# | 45 | module.layer4.0.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01424 | -0.00054 |    0.01098 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) |       2097152 |         838861 |    0.00000 |    0.00000 |  0.00000 | 59.99999 |  0.00000 |   59.99999 | 0.00870 | -0.00005 |    0.00497 |
# | 47 | module.layer4.1.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |         419431 |    0.00000 |    0.00000 |  0.00000 | 59.99994 |  0.00000 |   59.99994 | 0.01288 | -0.00056 |    0.00753 |
# | 48 | module.layer4.1.conv2.weight        | (512, 512, 3, 3)   |       2359296 |         778568 |    0.00000 |    0.00000 |  0.00000 | 15.62958 |  0.00000 |   66.99999 | 0.01029 | -0.00052 |    0.00561 |
# | 49 | module.layer4.1.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01400 | -0.00008 |    0.01080 |
# | 50 | module.layer4.2.conv1.weight        | (512, 2048, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01694 | -0.00039 |    0.01327 |
# | 51 | module.layer4.2.conv2.weight        | (512, 512, 3, 3)   |       2359296 |        2359296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01016 | -0.00059 |    0.00804 |
# | 52 | module.layer4.2.conv3.weight        | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01308 | -0.00000 |    0.00980 |
# | 53 | module.fc.weight                    | (1000, 2048)       |       2048000 |        2048000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03288 |  0.00000 |    0.02269 |
# | 54 | Total sparsity:                     | -                  |      25502912 |       18871702 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   26.00178 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-09-20 11:14:10,977 - Total sparsity: 26.00
#
# 2018-09-20 11:14:10,977 - --- validate (epoch=80)-----------
# 2018-09-20 11:14:10,977 - 128116 samples (256 per mini-batch)
# 2018-09-20 11:14:27,909 - Epoch: [80][   50/  500]    Loss 0.958656    Top1 76.281250    Top5 91.539062
# 2018-09-20 11:14:35,973 - Epoch: [80][  100/  500]    Loss 0.971032    Top1 76.289062    Top5 91.375000
# 2018-09-20 11:14:43,769 - Epoch: [80][  150/  500]    Loss 0.965900    Top1 76.359375    Top5 91.505208
# 2018-09-20 11:14:52,185 - Epoch: [80][  200/  500]    Loss 0.963459    Top1 76.472656    Top5 91.494141
# 2018-09-20 11:15:00,467 - Epoch: [80][  250/  500]    Loss 0.961311    Top1 76.487500    Top5 91.554688
# 2018-09-20 11:15:08,730 - Epoch: [80][  300/  500]    Loss 0.952356    Top1 76.649740    Top5 91.640625
# 2018-09-20 11:15:17,016 - Epoch: [80][  350/  500]    Loss 0.955011    Top1 76.588170    Top5 91.614955
# 2018-09-20 11:15:25,533 - Epoch: [80][  400/  500]    Loss 0.952346    Top1 76.601562    Top5 91.615234
# 2018-09-20 11:15:34,597 - Epoch: [80][  450/  500]    Loss 0.950455    Top1 76.662326    Top5 91.646701
# 2018-09-20 11:15:42,484 - Epoch: [80][  500/  500]    Loss 0.952648    Top1 76.621094    Top5 91.630469
# 2018-09-20 11:15:42,554 - ==> Top1: 76.618    Top5: 91.629    Loss: 0.953
#
# 2018-09-20 11:15:42,643 - ==> Best Top1: 77.734   On Epoch: 1
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 0.666113    Top1 82.640625    Top5 96.125000
# Test: [  100/  195]    Loss 0.788863    Top1 79.734375    Top5 95.066406
# Test: [  150/  195]    Loss 0.900865    Top1 77.450521    Top5 93.656250
# ==> Top1: 76.538    Top5: 93.184    Loss: 0.943

---
# Schedule-file format version.
version: 1
pruners:
  # Three AutomatedGradualPruner instances that differ only in their target
  # (final) sparsity and in the layers they prune. The "Fine (%)" column of
  # the parameter table in the header shows each listed layer ending at
  # roughly its pruner's final_sparsity (~60% / ~67% / ~76%).
  low_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.60
    weights:
      - module.layer2.0.conv2.weight
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer4.0.conv1.weight
      - module.layer4.0.conv2.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight

  mid_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.67
    weights:
      - module.layer1.1.conv1.weight
      - module.layer1.1.conv2.weight
      - module.layer2.0.conv1.weight
      - module.layer2.0.downsample.0.weight
      - module.layer3.0.conv2.weight
      - module.layer4.1.conv2.weight

  high_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.76
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer2.1.conv1.weight
      - module.layer2.1.conv2.weight
      - module.layer3.0.conv1.weight

lr_schedulers:
  # Exponential learning-rate decay with gamma = 0.9. Presumably this maps to
  # torch.optim.lr_scheduler.ExponentialLR (lr *= gamma on each step) —
  # confirm against the Distiller scheduler factory.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9


policies:
  # Pruning policies: every pruner acts every 2 epochs and stops at epoch 16.
  # low_pruner starts at epoch 0; mid_pruner and high_pruner start at epoch 4.
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 16
    frequency: 2

  - pruner:
      instance_name: mid_pruner
    starting_epoch: 4
    ending_epoch: 16
    frequency: 2

  - pruner:
      instance_name: high_pruner
    starting_epoch: 4
    ending_epoch: 16
    frequency: 2

  # LR decay is applied every epoch starting at epoch 13, overlapping the
  # last pruning epochs. NOTE(review): ending_epoch (100) is beyond the
  # 80-epoch run described in the header; presumably the decay just
  # continues until training stops.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 13
    ending_epoch: 100
    frequency: 1
