# This schedule performs element-wise (fine-grained) pruning of AlexNet, following the Automated Gradual Pruner (AGP) schedule of Zhu & Gupta.
#
# time python3 compress_classifier.py -a=alexnet --lr=0.005 -p=50 ../../../data.imagenet -j 22 --epochs 90 --pretrained --compress=../agp-pruning/alexnet.schedule_agp.yaml
#
# Parameters:
#
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | features.module.0.weight  | (64, 3, 11, 11)  |         23232 |          23232 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15003 | -0.00002 |    0.10000 |
# |  1 | features.module.3.weight  | (192, 64, 5, 5)  |        307200 |         116736 |    0.00000 |    0.00000 |  0.00000 |  2.19727 |  0.00000 |   62.00000 | 0.04665 | -0.00245 |    0.02222 |
# |  2 | features.module.6.weight  | (384, 192, 3, 3) |        663552 |         232244 |    0.00000 |    0.00000 |  0.00000 |  8.85824 |  0.00000 |   64.99988 | 0.03270 | -0.00179 |    0.01627 |
# |  3 | features.module.8.weight  | (256, 384, 3, 3) |        884736 |         504300 |    0.00000 |    0.00000 |  0.00000 |  0.44861 |  0.00000 |   42.99995 | 0.02758 | -0.00193 |    0.01720 |
# |  4 | features.module.10.weight | (256, 256, 3, 3) |        589824 |         336200 |    0.00000 |    0.00000 |  0.00000 |  0.68512 |  0.00000 |   42.99995 | 0.02836 | -0.00284 |    0.01795 |
# |  5 | classifier.1.weight       | (4096, 9216)     |      37748736 |        3397387 |    0.00000 |    0.21973 |  0.00000 |  0.21973 |  0.00000 |   91.00000 | 0.00567 | -0.00017 |    0.00151 |
# |  6 | classifier.4.weight       | (4096, 4096)     |      16777216 |        1509950 |    0.21973 |    3.93066 |  0.00000 |  3.93066 |  0.00000 |   91.00000 | 0.00798 | -0.00054 |    0.00217 |
# |  7 | classifier.6.weight       | (1000, 4096)     |       4096000 |        1024000 |    2.97852 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   75.00000 | 0.01688 |  0.00052 |    0.00753 |
# |  8 | Total sparsity:           | -                |      61090496 |        7144049 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   88.30579 | 0.00000 |  0.00000 |    0.00000 |
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 88.31
#
# --- validate (epoch=89)-----------
# 128116 samples (256 per mini-batch)
# Epoch: [89][   50/  500]    Loss 2.180059    Top1 51.671875    Top5 74.250000
# Epoch: [89][  100/  500]    Loss 2.174012    Top1 51.851562    Top5 74.386719
# Epoch: [89][  150/  500]    Loss 2.187923    Top1 51.520833    Top5 74.184896
# Epoch: [89][  200/  500]    Loss 2.192574    Top1 51.433594    Top5 74.156250
# Epoch: [89][  250/  500]    Loss 2.184967    Top1 51.459375    Top5 74.279688
# Epoch: [89][  300/  500]    Loss 2.184608    Top1 51.440104    Top5 74.239583
# Epoch: [89][  350/  500]    Loss 2.179280    Top1 51.537946    Top5 74.357143
# Epoch: [89][  400/  500]    Loss 2.182293    Top1 51.491211    Top5 74.354492
# Epoch: [89][  450/  500]    Loss 2.182311    Top1 51.440104    Top5 74.333333
# Epoch: [89][  500/  500]    Loss 2.183890    Top1 51.459375    Top5 74.328125
# ==> Top1: 51.456    Top5: 74.326    Loss: 2.185
#
# Saving checkpoint
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 1.493478    Top1 63.070312    Top5 85.898438
# Test: [  100/  195]    Loss 1.643213    Top1 60.539063    Top5 83.589844
# Test: [  150/  195]    Loss 1.836551    Top1 57.466146    Top5 80.273438
# ==> Top1: 56.528    Top5: 79.352    Loss: 1.897
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/private-distiller/examples/classifier_compression/logs/2018.03.30-142316/2018.03.30-142316.log
#
# real    664m25.914s
# user    13391m25.914s
# sys     1569m37.119s

# Version tag for this file — presumably the schedule-file format version
# expected by the tool that loads it (TODO confirm against the loader).
version: 1
# Pruner definitions. Each entry declares a named AutomatedGradualPruner
# instance with an initial and a final sparsity target and the list of weight
# tensors it covers. The names are referenced via `instance_name` in the
# `policies` list further down in this file.
pruners:
  # classifier.1 / classifier.4 weights: pruned to 91% sparsity.
  fc_pruner:
    class: 'AutomatedGradualPruner'
    initial_sparsity: 0.05
    final_sparsity: 0.91
    weights: ['classifier.1.weight', 'classifier.4.weight']

  # classifier.6 weights: pruned less aggressively (75%).
  fc3_pruner:
    class: 'AutomatedGradualPruner'
    initial_sparsity: 0.05
    final_sparsity: 0.75
    weights: ['classifier.6.weight']

  # Disabled: no pruner is active for features.module.0.weight, which the
  # results table in the file header reports as fully dense. Note that this
  # stanza never listed a `weights` entry; one must be added before enabling.
  # conv1_pruner:
  #   class: 'AutomatedGradualPruner'
  #   initial_sparsity: 0.03
  #   final_sparsity: 0.16

  # features.module.3 weights: pruned to 62% sparsity.
  conv2_pruner:
    class: 'AutomatedGradualPruner'
    initial_sparsity: 0.03
    final_sparsity: 0.62
    weights: ['features.module.3.weight']

  # features.module.6 weights: pruned to 65% sparsity.
  conv3_pruner:
    class: 'AutomatedGradualPruner'
    initial_sparsity: 0.03
    final_sparsity: 0.65
    weights: ['features.module.6.weight']

  # features.module.8 / features.module.10 weights share one pruner (43%).
  conv45_pruner:
    class: 'AutomatedGradualPruner'
    initial_sparsity: 0.03
    final_sparsity: 0.43
    weights: ['features.module.8.weight', 'features.module.10.weight']

# Learning-rate scheduler definitions, referenced by name from `policies`.
lr_schedulers:
  # Learning rate decay scheduler.
  # NOTE(review): presumably resolves to torch.optim.lr_scheduler.ExponentialLR,
  # which multiplies the learning rate by `gamma` on every scheduler step —
  # confirm against the loader.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9


# Policies bind each pruner / scheduler instance declared above to an epoch
# window (`starting_epoch` .. `ending_epoch`) and a `frequency`.
policies:
  # Convolution pruners: epoch window 1..11, frequency 2.
  - pruner:
      instance_name: 'conv2_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  - pruner:
      instance_name: 'conv3_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  - pruner:
      instance_name: 'conv45_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  # Fully-connected pruners: start one epoch earlier and use a longer window
  # (0..18) than the convolution pruners.
  - pruner:
      instance_name: 'fc_pruner'
    starting_epoch: 0
    ending_epoch: 18
    frequency: 2

  - pruner:
      instance_name: 'fc3_pruner'
    starting_epoch: 0
    ending_epoch: 18
    frequency: 2

  # Learning-rate decay starts at epoch 24, after every pruning window above
  # has ended.
  # NOTE(review): ending_epoch (200) is beyond the 90 epochs used by the
  # command line in the file header; presumably this just keeps the decay
  # active until training ends — confirm.
  - lr_scheduler:
      instance_name: 'pruning_lr'
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
