#
# This schedule performs filter ranking and removal, for the convolution layers in ResNet56-CIFAR, as described in
# Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
# ICLR 2017, arXiv:1608.08710
#
# Filters are ranked and pruned accordingly.
# This is followed by network thinning which removes the filters entirely from the model, and changes the convolution
# layers' dimensions accordingly.  Convolution layers that follow have their respective channels removed as well, as do
# batch normalization layers.
#
# The authors write that: "Since there is no projection mapping for choosing the identity featuremaps, we only
# consider pruning the first layer of the residual block."
#
# Note that to use the command-line below, you will need the baseline ResNet56 model (checkpoint.resnet56_cifar_baseline.pth.tar).
# You may either train this model from scratch, or download it from the link below.
# https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
#
# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
#
# Results: 53.9% (1.85x) of the original convolution MACs (when calculated using direct convolution)
#
# Baseline results:
#     Top1: 92.850    Top5: 99.780    Loss: 0.464
#     Parameters: 851,504
#     Total MACs: 125,747,840
#
# Results:
#     Top1: 92.910    Top5: 99.690    Loss: 0.376
#     Parameters: 570,704  (=33% sparse)
#     Total MACs: 67,797,632 (=1.85x less MACs)
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25775 |  0.01021 |    0.13389 |
# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08400 |  0.00016 |    0.04778 |
# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08528 | -0.00987 |    0.05921 |
# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08004 |  0.00482 |    0.05321 |
# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06757 | -0.01286 |    0.04833 |
# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07856 |  0.00001 |    0.05491 |
# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06927 | -0.00455 |    0.05305 |
# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08582 |  0.00157 |    0.06020 |
# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08446 | -0.00188 |    0.06118 |
# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08725 |  0.00491 |    0.06379 |
# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07957 | -0.01561 |    0.06278 |
# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10966 | -0.00952 |    0.07636 |
# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10307 |  0.00342 |    0.07403 |
# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10200 | -0.00991 |    0.07828 |
# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09630 |  0.00492 |    0.07303 |
# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11086 | -0.01292 |    0.08203 |
# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10022 |  0.00847 |    0.07685 |
# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10323 | -0.01191 |    0.07876 |
# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07648 |  0.00615 |    0.05865 |
# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09559 | -0.00432 |    0.07201 |
# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07934 | -0.00183 |    0.05900 |
# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16645 |  0.00714 |    0.11508 |
# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06975 | -0.00591 |    0.05336 |
# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05907 | -0.00069 |    0.04646 |
# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06476 | -0.00591 |    0.05015 |
# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05564 | -0.00607 |    0.04301 |
# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06246 | -0.00269 |    0.04888 |
# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05009 | -0.00171 |    0.03892 |
# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06047 | -0.00494 |    0.04774 |
# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04690 | -0.00493 |    0.03661 |
# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04318 | -0.00403 |    0.03144 |
# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03371 | -0.00219 |    0.02386 |
# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05296 | -0.00465 |    0.04163 |
# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04041 | -0.00099 |    0.03073 |
# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06190 | -0.00745 |    0.04871 |
# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04535 |  0.00052 |    0.03475 |
# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00309 |    0.02261 |
# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02457 | -0.00035 |    0.01523 |
# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06005 | -0.00122 |    0.04697 |
# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05147 | -0.00011 |    0.03767 |
# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09252 |  0.00159 |    0.06504 |
# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03913 | -0.00138 |    0.02846 |
# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03736 | -0.00428 |    0.02826 |
# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03982 | -0.00118 |    0.02987 |
# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03651 | -0.00484 |    0.02836 |
# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04171 | -0.00306 |    0.03253 |
# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03619 | -0.00400 |    0.02820 |
# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04342 | -0.00387 |    0.03380 |
# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03494 | -0.00264 |    0.02668 |
# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04597 | -0.00467 |    0.03630 |
# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03378 | -0.00285 |    0.02578 |
# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03428 | -0.00242 |    0.02700 |
# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02501 | -0.00001 |    0.01828 |
# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03189 | -0.00319 |    0.02489 |
# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02260 | -0.00034 |    0.01673 |
# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03693 | -0.00290 |    0.02890 |
# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02474 |  0.00102 |    0.01800 |
# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42898 | -0.00001 |    0.33739 |
# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=249)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 92.720    Top5: 99.820    Loss: 0.358
#
# ==> Best Top1: 93.160 on Epoch: 238
# Saving checkpoint to: logs/2018.11.22-121053/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 92.910    Top5: 99.690    Loss: 0.376
#
#
# Log file for this run: /home/cvds_lab/nzmora/sandbox_5/distiller/examples/classifier_compression/logs/2018.11.22-121053/2018.11.22-121053.log
#
# real    31m6.775s
# user    70m33.551s
# sys     7m19.538s


# Format version of this schedule file (presumably checked by the schedule
# loader - TODO confirm against the consumer).
version: 1
pruners:
  # Four filter pruners, one per target sparsity level. Each one lists only the
  # first convolution (conv1) of the residual blocks it covers; per the header
  # quote from the paper, only the first layer of a residual block is pruned.
  # The pruner class name indicates L1-based ranking of the structures.

  # layer1, blocks 0-8: 70% of the filters.
  filter_pruner_70:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.7
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.3.conv1.weight
      - module.layer1.4.conv1.weight
      - module.layer1.5.conv1.weight
      - module.layer1.6.conv1.weight
      - module.layer1.7.conv1.weight
      - module.layer1.8.conv1.weight

  # layer2: 60% of the filters. Blocks 0, 5 and 8 are not listed here, so this
  # pruner leaves them untouched.
  filter_pruner_60:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.6
    weights:
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer2.4.conv1.weight
      - module.layer2.6.conv1.weight
      - module.layer2.7.conv1.weight

  # layer3, block 1 only: a lighter 20% target.
  filter_pruner_20:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.2
    weights:
      - module.layer3.1.conv1.weight

  # layer3, remaining listed blocks: 40% of the filters. Blocks 0 and 4 are not
  # listed in any pruner.
  filter_pruner_40:
    class: L1RankedStructureParameterPruner
    group_type: Filters
    desired_sparsity: 0.4
    weights:
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer3.6.conv1.weight
      - module.layer3.7.conv1.weight
      - module.layer3.8.conv1.weight


extensions:
  # Network-thinning step. As the file header describes, thinning removes the
  # pruned filters from the model entirely and adjusts the dimensions of the
  # affected convolution and batch normalization layers.
  # Indentation uses the same 2-space step as the rest of this file.
  net_thinner:
    class: 'FilterRemover'
    thinning_func_str: remove_filters
    arch: 'resnet56_cifar'
    dataset: 'cifar10'

lr_schedulers:
  # Learning-rate schedule used for fine-tuning after pruning.
  # NOTE(review): the class name presumably resolves to PyTorch's
  # torch.optim.lr_scheduler.ExponentialLR, with gamma as its decay factor -
  # confirm against the schedule loader.
  # Indentation uses the same 2-space step as the rest of this file.
  exp_finetuning_lr:
    class: ExponentialLR
    gamma: 0.95


policies:
  # One-shot pruning: every pruner is scheduled for epoch 0 only.
  - pruner:
      instance_name: filter_pruner_70
    epochs:
      - 0

  - pruner:
      instance_name: filter_pruner_60
    epochs:
      - 0

  - pruner:
      instance_name: filter_pruner_40
    epochs:
      - 0

  - pruner:
      instance_name: filter_pruner_20
    epochs:
      - 0

  # The thinning extension is scheduled for the same epoch as the pruners and
  # is listed after them.
  - extension:
      instance_name: net_thinner
    epochs:
      - 0

  # Fine-tuning learning-rate schedule: configured for epochs 10 through 300
  # with a frequency of 1.
  - lr_scheduler:
      instance_name: exp_finetuning_lr
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1
