#
# This schedule is an example of "Iterative Pruning" for AlexNet/ImageNet, combined
# with 2D structure regularization for the convolution weights.
#
# time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 24 --epochs 90 --pretrained --compress=../hybrid/alexnet.schedule_sensitivity_2D-reg.yaml
#
# Parameters:
#
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | features.module.0.weight  | (64, 3, 11, 11)  |         23232 |          13380 |    0.00000 |    0.00000 |  0.00000 |  0.52083 |  0.00000 |   42.40702 | 0.15288 |  0.00001 |    0.09403 |
# |  1 | features.module.3.weight  | (192, 64, 5, 5)  |        307200 |         102744 |    0.00000 |    0.00000 |  0.00000 |  9.17969 |  0.00000 |   66.55469 | 0.04458 | -0.00215 |    0.02018 |
# |  2 | features.module.6.weight  | (384, 192, 3, 3) |        663552 |         176986 |    0.00000 |    0.00000 |  0.00000 | 29.33757 |  0.00000 |   73.32734 | 0.02720 | -0.00124 |    0.01197 |
# |  3 | features.module.8.weight  | (256, 384, 3, 3) |        884736 |         199956 |    0.00000 |    0.00000 |  0.00000 | 35.29867 |  0.00000 |   77.39936 | 0.02040 | -0.00092 |    0.00869 |
# |  4 | features.module.10.weight | (256, 256, 3, 3) |        589824 |         131286 |    0.00000 |    0.00000 |  0.00000 | 43.33954 |  0.00000 |   77.74150 | 0.02280 | -0.00154 |    0.00987 |
# |  5 | classifier.1.weight       | (4096, 9216)     |      37748736 |        3643767 |    0.00000 |    0.21973 |  0.00000 |  0.21973 |  0.00000 |   90.34731 | 0.00603 | -0.00019 |    0.00178 |
# |  6 | classifier.4.weight       | (4096, 4096)     |      16777216 |        1892052 |    0.21973 |    3.56445 |  0.00000 |  3.56445 |  0.00000 |   88.72249 | 0.00879 | -0.00067 |    0.00280 |
# |  7 | classifier.6.weight       | (1000, 4096)     |       4096000 |        1022778 |    3.44238 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   75.02983 | 0.01783 |  0.00039 |    0.00816 |
# |  8 | Total sparsity:           | -                |      61090496 |        7182950 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   88.24212 | 0.00000 |  0.00000 |    0.00000 |
# +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 88.24
#
# --- validate (epoch=89)-----------
# 128116 samples (256 per mini-batch)
# Epoch: [89][   50/  500]    Loss 2.179626    Top1 51.585938    Top5 73.976562
# Epoch: [89][  100/  500]    Loss 2.188121    Top1 51.261719    Top5 74.054688
# Epoch: [89][  150/  500]    Loss 2.186677    Top1 51.302083    Top5 74.096354
# Epoch: [89][  200/  500]    Loss 2.188725    Top1 51.195312    Top5 74.007812
# Epoch: [89][  250/  500]    Loss 2.184323    Top1 51.342188    Top5 74.150000
# Epoch: [89][  300/  500]    Loss 2.181935    Top1 51.441406    Top5 74.194010
# Epoch: [89][  350/  500]    Loss 2.180590    Top1 51.477679    Top5 74.223214
# Epoch: [89][  400/  500]    Loss 2.177557    Top1 51.538086    Top5 74.300781
# Epoch: [89][  450/  500]    Loss 2.178948    Top1 51.572049    Top5 74.275174
# Epoch: [89][  500/  500]    Loss 2.178128    Top1 51.576563    Top5 74.308594
# ==> Top1: 51.577    Top5: 74.305    Loss: 2.178
#
# Saving checkpoint
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 1.514649    Top1 62.546875    Top5 85.429688
# Test: [  100/  195]    Loss 1.659908    Top1 60.261719    Top5 83.367188
# Test: [  150/  195]    Loss 1.852519    Top1 57.171875    Top5 80.187500
# ==> Top1: 56.240    Top5: 79.246    Loss: 1.911
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.10-030052/2018.04.10-030052.log
#
# real    1031m47.668s
# user    23604m55.588s
# sys     1772m45.380s

version: 1

# Pruner definitions, keyed by instance name. The `pruner` entry under
# `policies` selects one of these through its `instance_name`.
pruners:
  my_pruner:
    class: SensitivityPruner
    # One sensitivity value per weight tensor, keyed by parameter name.
    # NOTE(review): presumably each value is multiplied by the tensor's
    # standard deviation to obtain the pruning threshold -- confirm against
    # the SensitivityPruner documentation.
    sensitivities:
      # Convolution weights (4-D shapes in the parameter table in the header).
      features.module.0.weight: 0.25
      features.module.3.weight: 0.35
      features.module.6.weight: 0.40
      features.module.8.weight: 0.45
      features.module.10.weight: 0.55
      # Classifier weights (2-D shapes in the parameter table in the header).
      classifier.1.weight: 0.875
      classifier.4.weight: 0.875
      classifier.6.weight: 0.625

# Regularizer definitions, keyed by instance name. The `regularizer` entry
# under `policies` selects one of these through its `instance_name`.
regularizers:
  2d_groups_regularizer:
    class: GroupLassoRegularizer
    # Each entry maps a parameter name to a [strength, group structure] pair.
    # Only the convolution weights are regularized; all of them use the same
    # strength and the '2D' structure.
    reg_regims:
      features.module.0.weight: [0.000012, '2D']
      features.module.3.weight: [0.000012, '2D']
      features.module.6.weight: [0.000012, '2D']
      features.module.8.weight: [0.000012, '2D']
      features.module.10.weight: [0.000012, '2D']
      # Entries for the classifier weights are disabled:
      # classifier.1.weight: [0.000012, '2D']
      # classifier.4.weight: [0.000012, '2D']
      # classifier.6.weight: [0.000012, '2D']


# Learning-rate scheduler definitions, keyed by instance name. The
# `lr_scheduler` entry under `policies` selects `pruning_lr` through its
# `instance_name`.
lr_schedulers:
  # Learning rate decay scheduler
  pruning_lr:
    class: ExponentialLR
    # NOTE(review): presumably the multiplicative decay factor applied at each
    # scheduler step (the PyTorch ExponentialLR `gamma` argument) -- confirm.
    gamma: 0.9

# Policies: each entry selects one instance defined earlier in this file by
# its `instance_name` and gives it a starting epoch, an ending epoch and a
# frequency.
policies:
  # Pruning: starting_epoch 0, ending_epoch 38, frequency 2.
  - pruner:
      instance_name: my_pruner
    starting_epoch: 0
    ending_epoch: 38
    frequency: 2

  # Group-lasso regularization: same epoch range as the pruner, frequency 1.
  - regularizer:
      instance_name: '2d_groups_regularizer'
    starting_epoch: 0
    ending_epoch: 38
    frequency: 1

  # Learning-rate decay: starts at epoch 24, while the pruner and regularizer
  # are still active.
  # NOTE(review): ending_epoch (200) is larger than the 90 epochs used by the
  # example command in the file header, so the decay presumably remains active
  # until the end of training -- confirm this is intended.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 24
    ending_epoch: 200
    frequency: 1
