# This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
# ICLR 2017, arXiv:1608.08710
#
# Filters are ranked and pruned accordingly.
# This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
# layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
# Batch normalization layers.
#
# The authors write that: "Since there is no projection mapping for choosing the identity featuremaps, we only
# consider pruning the first layer of the residual block."
#
# Note that to use the command-line below, you will need the baseline ResNet56 model (checkpoint.resnet56_cifar_baseline.pth.tar).
# You may either train this model from scratch, or download it from the link below.
# https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
#
# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar  --reset-optimizer --vs=0
#
# Results: 62.7% of the original convolution MACs (when calculated using direct convolution)
#
# Baseline results:
#     Top1: 92.850    Top5: 99.780    Loss: 0.464
#     Parameters: 851,504
#     Total MACs: 125,747,840
#
# Results:
#     Top1: 93.160    Top5: 99.770    Loss: 0.359
#     Parameters: 634,640 (=25.5% sparse)
#     Total MACs: 78,856,832  (=1.59x less MACs)
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.29813 | -0.00025 |    0.13038 |
# |  1 | module.layer1.0.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10107 |  0.00154 |    0.04532 |
# |  2 | module.layer1.0.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10037 | -0.01144 |    0.07036 |
# |  3 | module.layer1.1.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07281 |  0.00606 |    0.04960 |
# |  4 | module.layer1.1.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07149 | -0.00825 |    0.05371 |
# |  5 | module.layer1.2.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10620 |  0.00802 |    0.06828 |
# |  6 | module.layer1.2.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10025 | -0.01102 |    0.07251 |
# |  7 | module.layer1.3.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09117 |  0.00476 |    0.06612 |
# |  8 | module.layer1.3.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08281 | -0.01167 |    0.06335 |
# |  9 | module.layer1.4.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12840 | -0.00739 |    0.08541 |
# | 10 | module.layer1.4.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11029 | -0.00373 |    0.07547 |
# | 11 | module.layer1.5.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12273 |  0.00353 |    0.08644 |
# | 12 | module.layer1.5.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11482 | -0.01021 |    0.08623 |
# | 13 | module.layer1.6.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11548 | -0.00751 |    0.08488 |
# | 14 | module.layer1.6.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10436 | -0.00536 |    0.07751 |
# | 15 | module.layer1.7.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12513 | -0.00715 |    0.09247 |
# | 16 | module.layer1.7.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09655 |  0.00126 |    0.07188 |
# | 17 | module.layer1.8.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13188 | -0.00344 |    0.09825 |
# | 18 | module.layer1.8.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09572 |  0.00137 |    0.07336 |
# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10913 | -0.00127 |    0.08239 |
# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09105 | -0.00103 |    0.06896 |
# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.21201 |  0.00925 |    0.14586 |
# | 22 | module.layer2.1.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07899 | -0.00625 |    0.06075 |
# | 23 | module.layer2.1.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06983 | -0.00272 |    0.05518 |
# | 24 | module.layer2.2.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07068 | -0.00318 |    0.05479 |
# | 25 | module.layer2.2.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06181 | -0.00348 |    0.04839 |
# | 26 | module.layer2.3.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07141 | -0.00803 |    0.05594 |
# | 27 | module.layer2.3.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06190 |  0.00061 |    0.04827 |
# | 28 | module.layer2.4.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07364 | -0.00800 |    0.05776 |
# | 29 | module.layer2.4.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06012 |  0.00121 |    0.04699 |
# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06762 | -0.00774 |    0.05222 |
# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05282 | -0.00115 |    0.04040 |
# | 32 | module.layer2.6.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07211 | -0.00665 |    0.05668 |
# | 33 | module.layer2.6.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05485 | -0.00150 |    0.04209 |
# | 34 | module.layer2.7.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05631 | -0.00307 |    0.04426 |
# | 35 | module.layer2.7.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04464 | -0.00278 |    0.03306 |
# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05403 | -0.00609 |    0.03848 |
# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04022 | -0.00025 |    0.02761 |
# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07608 | -0.00286 |    0.05980 |
# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06714 | -0.00102 |    0.05225 |
# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11481 | -0.00192 |    0.08748 |
# | 41 | module.layer3.1.conv1.weight        | (58, 64, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05406 | -0.00239 |    0.04216 |
# | 42 | module.layer3.1.conv2.weight        | (64, 58, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05063 | -0.00539 |    0.03926 |
# | 43 | module.layer3.2.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05888 | -0.00395 |    0.04621 |
# | 44 | module.layer3.2.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04979 | -0.00531 |    0.03896 |
# | 45 | module.layer3.3.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06070 | -0.00426 |    0.04805 |
# | 46 | module.layer3.3.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04619 | -0.00235 |    0.03520 |
# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05023 | -0.00388 |    0.03945 |
# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03666 | -0.00140 |    0.02780 |
# | 49 | module.layer3.5.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03828 | -0.00302 |    0.03027 |
# | 50 | module.layer3.5.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02896 | -0.00195 |    0.02164 |
# | 51 | module.layer3.6.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03483 | -0.00169 |    0.02672 |
# | 52 | module.layer3.6.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02518 |  0.00037 |    0.01825 |
# | 53 | module.layer3.7.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03966 | -0.00293 |    0.03091 |
# | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02753 |  0.00051 |    0.02028 |
# | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02897 | -0.00308 |    0.02240 |
# | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02130 | -0.00061 |    0.01556 |
# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47902 | -0.00002 |    0.47518 |
# | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=249)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 92.680    Top5: 99.700    Loss: 0.354
#
# ==> Best Top1: 93.080 on Epoch: 243
# Saving checkpoint to: logs/2018.11.22-110009/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 93.160    Top5: 99.770    Loss: 0.359
#
#
# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.11.22-110009/2018.11.22-110009.log
#
# real    30m58.723s
# user    70m7.727s
# sys     7m16.697s

# Version number of this schedule file (presumably checked by the tool that
# loads the schedule -- confirm against the loader).
version: 1
# Filter pruners, one per target sparsity level.  Every pruner uses the same
# class and structure type (Filters); they differ only in `desired_sparsity`
# and in the list of weight tensors they are applied to.  Only the first
# convolution (conv1) of a residual block is ever listed.
pruners:
  # 60% filter sparsity on conv1 of every layer1 block (blocks 0-8).
  # In the results table above these tensors go from 16 to 7 filters.
  filter_pruner_60:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.6
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.3.conv1.weight
      - module.layer1.4.conv1.weight
      - module.layer1.5.conv1.weight
      - module.layer1.6.conv1.weight
      - module.layer1.7.conv1.weight
      - module.layer1.8.conv1.weight

  # 50% filter sparsity on conv1 of layer2 blocks 1-4, 6 and 7.
  # Blocks 0, 5 and 8 are not listed, so they keep all 32 filters.
  filter_pruner_50:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.5
    weights:
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer2.4.conv1.weight
      - module.layer2.6.conv1.weight
      - module.layer2.7.conv1.weight

  # 10% filter sparsity, applied to a single tensor: conv1 of layer3 block 1.
  filter_pruner_10:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.1
    weights:
      - module.layer3.1.conv1.weight

  # 30% filter sparsity on conv1 of layer3 blocks 2, 3 and 5-8.
  # Blocks 0 and 4 are not listed, so they keep all 64 filters.
  filter_pruner_30:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.3
    weights:
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer3.6.conv1.weight
      - module.layer3.7.conv1.weight
      - module.layer3.8.conv1.weight


# Extensions: non-pruner actions that can be scheduled from `policies`.
extensions:
  # Network thinning: per the header of this file, this removes the pruned
  # filters from the model and adjusts the dimensions of the affected layers.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet56_cifar
    dataset: cifar10

# Learning-rate schedulers that can be referenced from `policies`.
lr_schedulers:
  # Exponential decay with factor 0.95 (presumably PyTorch's ExponentialLR,
  # i.e. the LR is multiplied by `gamma` on every scheduler step -- confirm).
  exp_finetuning_lr:
    class: ExponentialLR
    gamma: 0.95


# Policies bind the instances declared above to the epochs on which they act.
# All four pruners and the thinning extension are scheduled for epoch 0 only.
# NOTE(review): the thinning entry is listed after the pruners; the file header
# says thinning follows pruning, so keep this relative order unless the
# scheduler is confirmed to be order-independent.
policies:
  # layer1 conv1 tensors, 60% filter sparsity.
  - pruner:
      instance_name: filter_pruner_60
    epochs: [0]

  # layer2 conv1 tensors, 50% filter sparsity.
  - pruner:
      instance_name: filter_pruner_50
    epochs: [0]

  # layer3 conv1 tensors (blocks 2, 3, 5-8), 30% filter sparsity.
  - pruner:
      instance_name: filter_pruner_30
    epochs: [0]

  # layer3 block 1 conv1, 10% filter sparsity.
  - pruner:
      instance_name: filter_pruner_10
    epochs: [0]

  # Remove the pruned filters from the model (network thinning).
  - extension:
      instance_name: net_thinner
    epochs: [0]

  # Fine-tuning LR decay: active from epoch 10 up to epoch 300, applied with
  # frequency 1 (presumably every epoch -- confirm).  The example command in
  # the header runs with --epochs=70, well below ending_epoch.
  - lr_scheduler:
      instance_name: exp_finetuning_lr
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1
