#
# This schedule is an example of "Iterative Pruning" for AlexNet/ImageNet, as
# described in chapter 3 of Song Han's PhD dissertation: "EFFICIENT METHODS AND
# HARDWARE FOR DEEP LEARNING"
#
# The pruning policy uses multiple pruning phases.  Each pruning phase is
# followed by a retraining phase.
# In this particular policy, pruning is scheduled every 2 epochs.
# After 38/2 pruning phases, pruning ends and only the retraining continues.
#
# time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j=44 --epochs=90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml
#
# Parameters:
#
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | features.module.0.weight  | (64, 3, 11, 11)  |         23232 |          13373 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   42.43716 | 0.14381 | -0.00002 |    0.08794 |
# |  1 | features.module.3.weight  | (192, 64, 5, 5)  |        307200 |         115322 |    0.00000 |    0.00000 |  0.00000 |  2.04264 |  0.00000 |   62.46029 | 0.04702 | -0.00248 |    0.02286 |
# |  2 | features.module.6.weight  | (384, 192, 3, 3) |        663552 |         256454 |    0.00000 |    0.00000 |  0.00000 |  6.13742 |  0.00000 |   61.35133 | 0.03354 | -0.00184 |    0.01803 |
# |  3 | features.module.8.weight  | (256, 384, 3, 3) |        884736 |         315278 |    0.00000 |    0.00000 |  0.00000 |  7.02922 |  0.00000 |   64.36474 | 0.02647 | -0.00168 |    0.01423 |
# |  4 | features.module.10.weight | (256, 256, 3, 3) |        589824 |         186861 |    0.00000 |    0.00000 |  0.00000 | 15.72266 |  0.00000 |   68.31919 | 0.02714 | -0.00245 |    0.01408 |
# |  5 | classifier.1.weight       | (4096, 9216)     |      37748736 |        3395124 |    0.00000 |    0.21973 |  0.00000 |  0.21973 |  0.00000 |   91.00599 | 0.00589 | -0.00020 |    0.00168 |
# |  6 | classifier.4.weight       | (4096, 4096)     |      16777216 |        1783541 |    0.21973 |    3.49121 |  0.00000 |  3.49121 |  0.00000 |   89.36927 | 0.00849 | -0.00066 |    0.00263 |
# |  7 | classifier.6.weight       | (1000, 4096)     |       4096000 |         993134 |    3.39355 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   75.75356 | 0.01718 |  0.00029 |    0.00777 |
# |  8 | Total sparsity:           | -                |      61090496 |        7059087 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   88.44487 | 0.00000 |  0.00000 |    0.00000 |
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 88.44
#
# --- validate (epoch=89)-----------
# 128116 samples (256 per mini-batch)
# Epoch: [89][   50/  500]    Loss 2.149753    Top1 51.976562    Top5 74.859375
# Epoch: [89][  100/  500]    Loss 2.154934    Top1 51.941406    Top5 74.550781
# Epoch: [89][  150/  500]    Loss 2.159868    Top1 51.880208    Top5 74.513021
# Epoch: [89][  200/  500]    Loss 2.158245    Top1 51.875000    Top5 74.597656
# Epoch: [89][  250/  500]    Loss 2.150266    Top1 51.920313    Top5 74.667187
# Epoch: [89][  300/  500]    Loss 2.152199    Top1 51.933594    Top5 74.682292
# Epoch: [89][  350/  500]    Loss 2.152126    Top1 51.952009    Top5 74.684152
# Epoch: [89][  400/  500]    Loss 2.153599    Top1 51.949219    Top5 74.648438
# Epoch: [89][  450/  500]    Loss 2.151281    Top1 52.046875    Top5 74.703993
# Epoch: [89][  500/  500]    Loss 2.149620    Top1 52.032031    Top5 74.765625
# ==> Top1: 52.029    Top5: 74.767    Loss: 2.150
#
# Saving checkpoint
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 1.484814    Top1 63.328125    Top5 85.820312
# Test: [  100/  195]    Loss 1.636993    Top1 60.835938    Top5 83.617188
# Test: [  150/  195]    Loss 1.832027    Top1 57.713542    Top5 80.330729
# ==> Top1: 56.762    Top5: 79.340    Loss: 1.892
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/private-distiller/examples/classifier_compression/logs/2018.04.08-154509/2018.04.08-154509.log
#
# real    646m54.061s
# user    14899m29.068s
# sys     1901m19.958s

# Schedule format version tag.
version: 1

# Pruner definitions, keyed by instance name.  The policies section refers to
# a pruner through this name.
pruners:
  pruner1:
    class: SensitivityPruner
    # One sensitivity value per weight tensor, keyed by parameter name.
    # NOTE(review): presumably each value scales the tensor's standard
    # deviation to form the pruning threshold -- confirm against the
    # SensitivityPruner documentation.
    sensitivities:
      # 'features.*' parameters: lower sensitivities.
      features.module.0.weight: 0.25
      features.module.3.weight: 0.35
      features.module.6.weight: 0.40
      features.module.8.weight: 0.45
      features.module.10.weight: 0.55
      # 'classifier.*' parameters: higher sensitivities.
      classifier.1.weight: 0.875
      classifier.4.weight: 0.875
      classifier.6.weight: 0.625

# Learning-rate scheduler definitions, keyed by instance name.
lr_schedulers:
  pruning_lr:
    # NOTE(review): presumably resolves to
    # torch.optim.lr_scheduler.ExponentialLR, which multiplies the learning
    # rate by `gamma` on every scheduler step -- confirm.
    class: ExponentialLR
    gamma: 0.9

# Policies bind the instances defined above to an epoch range and a frequency.
policies:
  # Pruning: invoke 'pruner1' every 2 epochs between epoch 0 and epoch 38.
  - pruner:
      instance_name: pruner1
    starting_epoch: 0
    ending_epoch: 38
    frequency: 2

  # LR decay: invoke 'pruning_lr' every epoch, starting at epoch 24.
  # ending_epoch (200) exceeds the --epochs=90 of the example command in the
  # file header, so presumably the scheduler stays active until the run ends.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
