#
# This schedule uses the average percentage of zeros (APoZ) in the activations, to rank filters.
# Compare this to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml - the pruning time is
# much longer due to the callbacks required for collecting the activation statistics (this can be improved by disabling
# the collection of detailed records, for example).
# This provides 53.92% compute compression (x1.85), i.e. the pruned model needs 53.92% of the baseline MACs.
#
# Baseline results:
#     Top1: 92.850    Top5: 99.780    Loss: 0.364
#     Total MACs: 125,747,840
#     Total parameters: 851504
# Results:
#     Top1: 92.590    Top5: 99.630    Loss: 1.537
#     Total MACs: 67,797,632
#     Total parameters: 570704 (67.02%)
#
# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz_v2.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25855 |  0.00995 |    0.13518 |
# |  1 | module.layer1.0.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08280 | -0.00200 |    0.04624 |
# |  2 | module.layer1.0.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08243 | -0.01280 |    0.05715 |
# |  3 | module.layer1.1.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07365 |  0.00145 |    0.04511 |
# |  4 | module.layer1.1.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06361 | -0.01105 |    0.04664 |
# |  5 | module.layer1.2.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07599 | -0.00007 |    0.05344 |
# |  6 | module.layer1.2.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07223 | -0.00321 |    0.05500 |
# |  7 | module.layer1.3.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08744 |  0.00275 |    0.06146 |
# |  8 | module.layer1.3.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08422 | -0.01179 |    0.06336 |
# |  9 | module.layer1.4.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07359 | -0.00328 |    0.05202 |
# | 10 | module.layer1.4.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07655 | -0.00830 |    0.05807 |
# | 11 | module.layer1.5.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10343 | -0.00167 |    0.06877 |
# | 12 | module.layer1.5.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10276 | -0.00419 |    0.07489 |
# | 13 | module.layer1.6.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08513 | -0.00885 |    0.06337 |
# | 14 | module.layer1.6.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08756 | -0.00307 |    0.06605 |
# | 15 | module.layer1.7.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10795 | -0.01498 |    0.07969 |
# | 16 | module.layer1.7.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09896 |  0.00700 |    0.07549 |
# | 17 | module.layer1.8.conv1.weight        | (5, 16, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09739 | -0.01620 |    0.07525 |
# | 18 | module.layer1.8.conv2.weight        | (16, 5, 3, 3)  |           720 |            720 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07876 |  0.00671 |    0.05963 |
# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09674 | -0.00379 |    0.07281 |
# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07978 | -0.00164 |    0.05939 |
# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16693 |  0.00597 |    0.11503 |
# | 22 | module.layer2.1.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06721 | -0.00298 |    0.05084 |
# | 23 | module.layer2.1.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05842 | -0.00357 |    0.04602 |
# | 24 | module.layer2.2.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05689 | -0.00687 |    0.04438 |
# | 25 | module.layer2.2.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05025 | -0.00454 |    0.03910 |
# | 26 | module.layer2.3.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05853 | -0.00510 |    0.04560 |
# | 27 | module.layer2.3.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04758 | -0.00213 |    0.03705 |
# | 28 | module.layer2.4.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05571 | -0.00672 |    0.04273 |
# | 29 | module.layer2.4.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04397 | -0.00459 |    0.03407 |
# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04363 | -0.00326 |    0.03157 |
# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03392 | -0.00203 |    0.02392 |
# | 32 | module.layer2.6.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04665 | -0.00607 |    0.03589 |
# | 33 | module.layer2.6.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03699 | -0.00209 |    0.02802 |
# | 34 | module.layer2.7.conv1.weight        | (13, 32, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05658 | -0.00710 |    0.04287 |
# | 35 | module.layer2.7.conv2.weight        | (32, 13, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04273 | -0.00281 |    0.03222 |
# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03511 | -0.00271 |    0.02330 |
# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02510 | -0.00009 |    0.01561 |
# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06056 | -0.00144 |    0.04736 |
# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05184 | -0.00042 |    0.03792 |
# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09288 |  0.00286 |    0.06527 |
# | 41 | module.layer3.1.conv1.weight        | (52, 64, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03912 | -0.00155 |    0.02842 |
# | 42 | module.layer3.1.conv2.weight        | (64, 52, 3, 3) |         29952 |          29952 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03727 | -0.00412 |    0.02804 |
# | 43 | module.layer3.2.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03857 | -0.00132 |    0.02876 |
# | 44 | module.layer3.2.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03551 | -0.00487 |    0.02752 |
# | 45 | module.layer3.3.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04020 | -0.00281 |    0.03119 |
# | 46 | module.layer3.3.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03524 | -0.00470 |    0.02756 |
# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04392 | -0.00341 |    0.03419 |
# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03533 | -0.00286 |    0.02693 |
# | 49 | module.layer3.5.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04352 | -0.00462 |    0.03410 |
# | 50 | module.layer3.5.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03269 | -0.00326 |    0.02494 |
# | 51 | module.layer3.6.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03254 | -0.00281 |    0.02521 |
# | 52 | module.layer3.6.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02411 | -0.00045 |    0.01755 |
# | 53 | module.layer3.7.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03012 | -0.00277 |    0.02301 |
# | 54 | module.layer3.7.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02162 | -0.00051 |    0.01584 |
# | 55 | module.layer3.8.conv1.weight        | (39, 64, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03413 | -0.00266 |    0.02653 |
# | 56 | module.layer3.8.conv2.weight        | (64, 39, 3, 3) |         22464 |          22464 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02345 |  0.00058 |    0.01716 |
# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43253 | -0.00001 |    0.34022 |
# | 58 | Total sparsity:                     | -              |        570704 |         570704 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=249)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 91.880    Top5: 99.520    Loss: 1.543
#
# ==> Best Top1: 92.680   On Epoch: 237
#
# Saving checkpoint to: logs/2018.10.16-115506/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 92.590    Top5: 99.630    Loss: 1.537
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-115506/2018.10.16-115506.log
#
# real    49m10.711s
# user    91m13.215s
# sys     8m34.108s

version: 1

# Filter pruners. Each entry ranks whole filters of the listed convolution
# weights by the APoZ of their activations and is driven by the matching entry
# in the `policies` section.
#
# NOTE(review): the numeric suffix of each pruner name does not equal its
# final_sparsity (e.g. filter_pruner_60 targets 0.7). The names are kept as-is
# because `policies` refers to them by instance_name.
pruners:
  # conv1 of every layer1 block (0-8).
  filter_pruner_60:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.7
    group_type: Filters
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.3.conv1.weight
      - module.layer1.4.conv1.weight
      - module.layer1.5.conv1.weight
      - module.layer1.6.conv1.weight
      - module.layer1.7.conv1.weight
      - module.layer1.8.conv1.weight

  # conv1 of layer2 blocks 1-4, 6 and 7 (blocks 0, 5 and 8 are not listed).
  filter_pruner_50:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.6
    group_type: Filters
    weights:
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer2.4.conv1.weight
      - module.layer2.6.conv1.weight
      - module.layer2.7.conv1.weight

  # conv1 of layer3 block 1 only; the one pruner that starts from 0 sparsity.
  filter_pruner_10:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0
    final_sparsity: 0.2
    group_type: Filters
    weights:
      - module.layer3.1.conv1.weight

  # conv1 of layer3 blocks 2, 3 and 5-8 (blocks 0 and 4 are not listed).
  filter_pruner_30:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.4
    group_type: Filters
    weights:
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer3.6.conv1.weight
      - module.layer3.7.conv1.weight
      - module.layer3.8.conv1.weight


# Extensions referenced from the `policies` section.
extensions:
  # FilterRemover configured with the remove_filters thinning function for the
  # resnet56_cifar / cifar10 model.
  # NOTE(review): presumably this physically removes the pruned filters (the
  # parameter table in the file header lists reduced shapes at 0% sparsity) -
  # confirm against the FilterRemover implementation.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet56_cifar
    dataset: cifar10

# Learning-rate schedulers referenced from the `policies` section.
lr_schedulers:
  # ExponentialLR with gamma 0.95.
  # NOTE(review): presumably the LR is multiplied by gamma on every scheduler
  # step - confirm against the ExponentialLR documentation.
  exp_finetuning_lr:
    class: ExponentialLR
    gamma: 0.95


# Schedule: binds each pruner, extension and LR scheduler defined in this file
# to an epoch range. All four pruners share the same window (epochs 1 to 20,
# frequency 2).
policies:
  # layer1 pruner.
  - pruner:
      instance_name: filter_pruner_60
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  # layer2 pruner.
  - pruner:
      instance_name: filter_pruner_50
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  # layer3 pruner (blocks 2, 3, 5-8).
  - pruner:
      instance_name: filter_pruner_30
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  # layer3 pruner (block 1).
  - pruner:
      instance_name: filter_pruner_10
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  # Thinning is invoked once, at epoch 20 - the same value used as the
  # pruners' ending_epoch.
  - extension:
      instance_name: net_thinner
    epochs:
      - 20

  # LR scheduler is active from epoch 10 to 300 with frequency 1, so it
  # overlaps the second half of the pruning window.
  - lr_scheduler:
      instance_name: exp_finetuning_lr
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1
