# This schedule performs element-wise (fine-grained) pruning, combined with 2D group
# regularization, following the Automated Gradual Pruner (AGP, Zhu & Gupta) schedule.
#
# time python3 compress_classifier.py -a=alexnet --lr=0.005 -p=50 ../../../data.imagenet -j 22 --epochs 90 --pretrained --compress=../hybrid/alexnet.schedule_agp_2Dreg.yaml
#
# Parameters:
#
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | features.module.0.weight  | (64, 3, 11, 11)  |         23232 |          23232 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15382 | -0.00000 |    0.10280 |
# |  1 | features.module.3.weight  | (192, 64, 5, 5)  |        307200 |         116736 |    0.00000 |    0.00000 |  0.00000 |  4.19108 |  0.00000 |   62.00000 | 0.04541 | -0.00230 |    0.02114 |
# |  2 | features.module.6.weight  | (384, 192, 3, 3) |        663552 |         232244 |    0.00000 |    0.00000 |  0.00000 | 14.28358 |  0.00000 |   64.99988 | 0.02982 | -0.00155 |    0.01429 |
# |  3 | features.module.8.weight  | (256, 384, 3, 3) |        884736 |         504300 |    0.00000 |    0.00000 |  0.00000 |  2.13928 |  0.00000 |   42.99995 | 0.02435 | -0.00159 |    0.01458 |
# |  4 | features.module.10.weight | (256, 256, 3, 3) |        589824 |         336200 |    0.00000 |    0.00000 |  0.00000 |  2.74200 |  0.00000 |   42.99995 | 0.02570 | -0.00239 |    0.01572 |
# |  5 | classifier.1.weight       | (4096, 9216)     |      37748736 |        3397387 |    0.00000 |    0.21973 |  0.00000 |  0.21973 |  0.00000 |   91.00000 | 0.00569 | -0.00017 |    0.00152 |
# |  6 | classifier.4.weight       | (4096, 4096)     |      16777216 |        1509950 |    0.21973 |    4.15039 |  0.00000 |  4.15039 |  0.00000 |   91.00000 | 0.00809 | -0.00054 |    0.00221 |
# |  7 | classifier.6.weight       | (1000, 4096)     |       4096000 |        1024000 |    3.10059 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   75.00000 | 0.01719 |  0.00057 |    0.00767 |
# |  8 | Total sparsity:           | -                |      61090496 |        7144049 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   88.30579 | 0.00000 |  0.00000 |    0.00000 |
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 88.31
#
# --- validate (epoch=89)-----------
# 128116 samples (256 per mini-batch)
# Epoch: [89][   50/  500]    Loss 2.201010    Top1 51.210938    Top5 74.523438
# Epoch: [89][  100/  500]    Loss 2.177135    Top1 51.625000    Top5 74.785156
# Epoch: [89][  150/  500]    Loss 2.189493    Top1 51.239583    Top5 74.572917
# Epoch: [89][  200/  500]    Loss 2.190784    Top1 51.300781    Top5 74.494141
# Epoch: [89][  250/  500]    Loss 2.192392    Top1 51.306250    Top5 74.446875
# Epoch: [89][  300/  500]    Loss 2.194373    Top1 51.406250    Top5 74.444010
# Epoch: [89][  350/  500]    Loss 2.194265    Top1 51.440848    Top5 74.410714
# Epoch: [89][  400/  500]    Loss 2.192004    Top1 51.493164    Top5 74.432617
# Epoch: [89][  450/  500]    Loss 2.194460    Top1 51.421007    Top5 74.355903
# Epoch: [89][  500/  500]    Loss 2.196544    Top1 51.345313    Top5 74.357812
# ==> Top1: 51.338    Top5: 74.353    Loss: 2.197
#
# Saving checkpoint
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 1.496703    Top1 62.984375    Top5 85.507812
# Test: [  100/  195]    Loss 1.647965    Top1 60.425781    Top5 83.476562
# Test: [  150/  195]    Loss 1.840966    Top1 57.302083    Top5 80.260417
# ==> Top1: 56.406    Top5: 79.318    Loss: 1.901
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/private-distiller/examples/classifier_compression/logs/2018.03.30-022204/2018.03.30-022204.log
#
# real    669m2.348s
# user    13577m24.604s
# sys     1588m34.821s

# Version of this schedule file's format.
# NOTE(review): presumably checked by whatever loads the schedule -- confirm.
version: 1
pruners:
  # One entry per group of layers.  Every entry names the pruner class, the
  # initial/final sparsity it is configured with, and the weight tensors it
  # covers.  The final_sparsity values line up with the "Fine (%)" column of
  # the results table in the header of this file.

  # First two fully-connected layers (91% in the results table).
  fc_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.91
    weights:
      - 'classifier.1.weight'
      - 'classifier.4.weight'

  # Last fully-connected layer, pruned less aggressively (75%).
  fc3_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.75
    weights:
      - 'classifier.6.weight'

  # Disabled: the first convolution is not pruned (features.module.0.weight
  # shows 0% sparsity in the results table).  Note that this stub carries no
  # `weights` list; one would have to be added if it is ever re-enabled.
  # conv1_pruner:
  #   class: 'AutomatedGradualPruner'
  #   initial_sparsity : 0.03
  #   final_sparsity: 0.16

  # Second convolution (62%).
  conv2_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.03
    final_sparsity: 0.62
    weights:
      - 'features.module.3.weight'

  # Third convolution (65%).
  conv3_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.03
    final_sparsity: 0.65
    weights:
      - 'features.module.6.weight'

  # Fourth and fifth convolutions share one configuration (43%).
  conv45_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.03
    final_sparsity: 0.43
    weights:
      - 'features.module.8.weight'
      - 'features.module.10.weight'

lr_schedulers:
  # Learning-rate decay scheduler: ExponentialLR with gamma 0.9.  The epochs
  # during which it is active are set by the matching entry under `policies`.
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9

regularizers:
  # Group-lasso regularization over the five convolution weight tensors.
  2d_groups_regularizer:
    class: GroupLassoRegularizer
    # Maps a parameter name to a two-element list.
    # NOTE(review): presumably [regularization strength, group structure],
    # with '2D' selecting the 2D groups mentioned in the file header -- confirm.
    reg_regims:
      'features.module.0.weight': [0.000012, '2D']
      'features.module.3.weight': [0.000012, '2D']
      'features.module.6.weight': [0.000012, '2D']
      'features.module.8.weight': [0.000012, '2D']
      'features.module.10.weight': [0.000012, '2D']
      # The fully-connected (classifier) weights are left out of the
      # regularizer; the entries are kept here, disabled.
      # 'classifier.1.weight': [0.000012, '2D']
      # 'classifier.4.weight': [0.000012, '2D']
      # 'classifier.6.weight': [0.000012, '2D']

policies:
  # Each policy binds one of the instances defined above (by instance_name)
  # to a starting epoch, an ending epoch and a frequency.
  #
  # Pretrained ImageNet models load with the epoch counter set to 0 rather
  # than the epoch to which they were trained, so the epoch numbers below
  # are relative to the start of this run.

  # Convolution pruners: epochs 1 to 11, frequency 2.
  - pruner:
      instance_name: 'conv2_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  - pruner:
      instance_name: 'conv3_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  - pruner:
      instance_name: 'conv45_pruner'
    starting_epoch: 1
    ending_epoch: 11
    frequency: 2

  # Fully-connected pruners: start one epoch earlier (0) and run longer
  # (to 18), same frequency.
  - pruner:
      instance_name: 'fc_pruner'
    starting_epoch: 0
    ending_epoch: 18
    frequency: 2

  - pruner:
      instance_name: 'fc3_pruner'
    starting_epoch: 0
    ending_epoch: 18
    frequency: 2

  # Learning-rate decay begins at epoch 24, after every pruner has finished.
  # The ending epoch (200) lies beyond the 90 epochs used by the command in
  # the file header, so the decay stays active for the rest of that run.
  - lr_scheduler:
      instance_name: 'pruning_lr'
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1

  # Group regularization is applied from epoch 0 up to epoch 24, the epoch
  # at which learning-rate decay starts.
  - regularizer:
      instance_name: '2d_groups_regularizer'
    starting_epoch: 0
    ending_epoch: 24
    frequency: 1
