# This is an example of one-shot channel pruning.
# It is very similar to the pruning schedule described in
#   Pruning Filters for Efficient Convnets, H. Li, A. Kadav, I. Durdanovic, H. Samet, and H. P. Graf.
#   ICLR 2017, arXiv:1608.08710
# However, instead of one-shot filter ranking and pruning, we perform one-shot channel ranking and
# pruning, using L1-magnitude of the structures.
#
# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_channel_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
#
# Baseline results:
#     Top1: 92.850    Top5: 99.780    Loss: 0.464
#     Parameters: 851,504
#     Total MACs: 125,747,840
#
# Results:
#     Top1: 92.580    Top5: 99.670    Loss: 0.378
#     Parameters: 566,887 (=33.4% sparse)
#     Total MACs: 66,592,384  (=1.89x less MACs)
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (5, 3, 3, 3)   |           135 |            135 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.51071 |  0.02961 |    0.34620 |
# |  1 | module.layer1.0.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10577 | -0.00027 |    0.06610 |
# |  2 | module.layer1.0.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09014 | -0.00534 |    0.05336 |
# |  3 | module.layer1.1.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09359 | -0.00072 |    0.05725 |
# |  4 | module.layer1.1.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06693 | -0.00967 |    0.04221 |
# |  5 | module.layer1.2.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11141 | -0.00108 |    0.07825 |
# |  6 | module.layer1.2.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08755 |  0.00565 |    0.06267 |
# |  7 | module.layer1.3.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12567 | -0.00626 |    0.09563 |
# |  8 | module.layer1.3.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10765 | -0.00790 |    0.08238 |
# |  9 | module.layer1.4.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11128 | -0.00219 |    0.07703 |
# | 10 | module.layer1.4.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09995 | -0.00546 |    0.06941 |
# | 11 | module.layer1.5.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13976 | -0.00127 |    0.09434 |
# | 12 | module.layer1.5.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11994 |  0.01183 |    0.08470 |
# | 13 | module.layer1.6.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14581 | -0.00892 |    0.10762 |
# | 14 | module.layer1.6.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11972 |  0.00398 |    0.08729 |
# | 15 | module.layer1.7.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13305 |  0.00098 |    0.09495 |
# | 16 | module.layer1.7.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10282 | -0.00396 |    0.07579 |
# | 17 | module.layer1.8.conv1.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14427 |  0.00264 |    0.10199 |
# | 18 | module.layer1.8.conv2.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10895 |  0.00739 |    0.07599 |
# | 19 | module.layer2.0.conv1.weight        | (32, 5, 3, 3)  |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19751 | -0.00411 |    0.14773 |
# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10607 | -0.00486 |    0.07701 |
# | 21 | module.layer2.0.downsample.0.weight | (32, 5, 1, 1)  |           160 |            160 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25461 |  0.01664 |    0.19264 |
# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09205 | -0.00378 |    0.06871 |
# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08184 | -0.00368 |    0.06422 |
# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08441 | -0.00480 |    0.06477 |
# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07339 | -0.00748 |    0.05726 |
# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08118 | -0.00391 |    0.06368 |
# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06810 | -0.00177 |    0.05296 |
# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07782 | -0.00768 |    0.06072 |
# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06117 | -0.00520 |    0.04731 |
# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05832 | -0.00430 |    0.04224 |
# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04601 | -0.00230 |    0.03286 |
# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06935 | -0.00572 |    0.05344 |
# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05382 | -0.00365 |    0.04143 |
# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07991 | -0.00900 |    0.06264 |
# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06059 | -0.00253 |    0.04624 |
# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04541 | -0.00436 |    0.02956 |
# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03298 | -0.00051 |    0.02021 |
# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07679 | -0.00169 |    0.05996 |
# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06658 | -0.00063 |    0.04878 |
# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11350 |  0.00252 |    0.07997 |
# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05254 | -0.00156 |    0.03844 |
# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05045 | -0.00537 |    0.03825 |
# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05230 | -0.00151 |    0.03934 |
# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04915 | -0.00644 |    0.03828 |
# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05474 | -0.00361 |    0.04263 |
# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04832 | -0.00569 |    0.03775 |
# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05756 | -0.00462 |    0.04486 |
# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04677 | -0.00323 |    0.03594 |
# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06053 | -0.00528 |    0.04773 |
# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04602 | -0.00363 |    0.03534 |
# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04386 | -0.00286 |    0.03391 |
# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03310 | -0.00052 |    0.02422 |
# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03997 | -0.00282 |    0.03058 |
# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02942 | -0.00015 |    0.02128 |
# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05219 | -0.00294 |    0.04075 |
# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03520 |  0.00144 |    0.02545 |
# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.49462 | -0.00002 |    0.39447 |
# | 58 | Total sparsity:                     | -              |        566887 |         566887 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=249)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 92.460    Top5: 99.670    Loss: 0.381
#
# ==> Best Top1: 92.580 on Epoch: 248
# Saving checkpoint to: logs/2018.11.29-145258/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 92.460    Top5: 99.670    Loss: 0.381




# Version number of this schedule file.
# NOTE(review): presumably checked by the schedule loader - confirm.
version: 1
# Channel pruners, one instance per target sparsity level. Each instance ranks
# structures of type "Channels" in the listed weight tensors.
# NOTE(review): the instance names say "filter" although every pruner below
# uses group_type Channels; the names are kept unchanged because the policies
# section refers to them by instance_name.
pruners:
  # 70% channel sparsity: conv1 of every layer1 block, plus the two
  # convolutions at the entry of layer2.
  filter_pruner_70:
    class: L1RankedStructureParameterPruner
    group_type: Channels
    desired_sparsity: 0.7
    # NOTE(review): with a Leader dependency the order of the weights list is
    # presumably significant (first entry acting as the leader) - confirm
    # against the pruner implementation before reordering.
    group_dependency: Leader
    weights:
      - module.layer1.1.conv1.weight
      - module.layer1.0.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.3.conv1.weight
      - module.layer1.4.conv1.weight
      - module.layer1.5.conv1.weight
      - module.layer1.6.conv1.weight
      - module.layer1.7.conv1.weight
      - module.layer1.8.conv1.weight
      - module.layer2.0.conv1.weight
      - module.layer2.0.downsample.0.weight

  # 60% channel sparsity: conv2 of selected layer2 blocks
  # (blocks 0, 5 and 8 are not listed here).
  filter_pruner_60:
    class: L1RankedStructureParameterPruner
    group_type: Channels
    desired_sparsity: 0.6
    weights:
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.3.conv2.weight
      - module.layer2.4.conv2.weight
      - module.layer2.6.conv2.weight
      - module.layer2.7.conv2.weight

  # 20% channel sparsity: a single layer3 tensor.
  filter_pruner_20:
    class: L1RankedStructureParameterPruner
    group_type: Channels
    desired_sparsity: 0.2
    weights:
      - module.layer3.1.conv2.weight

  # 40% channel sparsity: conv2 of selected layer3 blocks
  # (blocks 0, 1 and 4 are not listed here).
  filter_pruner_40:
    class: L1RankedStructureParameterPruner
    group_type: Channels
    desired_sparsity: 0.4
    weights:
      - module.layer3.2.conv2.weight
      - module.layer3.3.conv2.weight
      - module.layer3.5.conv2.weight
      - module.layer3.6.conv2.weight
      - module.layer3.7.conv2.weight
      - module.layer3.8.conv2.weight


# Extension instances that policies can refer to by name.
# NOTE(review): net_thinner presumably removes the pruned channels from the
# model so the tensors actually shrink - confirm against StructureRemover.
extensions:
  net_thinner:
    class: StructureRemover
    thinning_func_str: remove_channels
    arch: resnet56_cifar
    dataset: cifar10

# Learning-rate scheduler instances that policies can refer to by name.
lr_schedulers:
  # Exponential schedule with gamma 0.95; its active epoch range is set by the
  # policy that references this instance.
  exp_finetuning_lr:
    class: ExponentialLR
    gamma: 0.95


# Binds the instances defined in the sections above to the epochs at which
# they run. Every pruner and the thinning extension is scheduled for epoch 0
# only; the LR scheduler covers a range of epochs.
# NOTE(review): list order is kept exactly as written (pruners first, then the
# thinning extension) - presumably the execution order within epoch 0; confirm
# before reordering.
policies:
  - pruner:
      instance_name: filter_pruner_70
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_60
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_40
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_20
    epochs: [0]

  # Runs in the same epoch as the pruners, listed after all of them.
  - extension:
      instance_name: net_thinner
    epochs: [0]

  # LR schedule configured for epochs 10 through 300 with frequency 1.
  - lr_scheduler:
      instance_name: exp_finetuning_lr
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1
