#
# This schedule uses the average percentage of zeros (APoZ) in the activations to rank filters.
# Compare this to examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml: the pruning time here is
# much longer because of the callbacks required to collect the activation statistics (this can be mitigated, for example,
# by disabling the collection of detailed records).
# The pruned network needs only 62.7% of the baseline compute (78.9M vs. 125.7M MACs, a ~1.6x reduction), while Top1
# slightly increases (93.03 vs. 92.85).
#
# Baseline results:
#     Top1: 92.850    Top5: 99.780    Loss: 0.364
#     Total MACs: 125,747,840
#     Total parameters: 851504
# Results:
#     Top1: 93.030    Top5: 99.650    Loss: 1.533
#     Total MACs: 78,856,832
#     Total parameters: 634640 (74.53%)
#
#
# time python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../network_trimming/resnet56_cifar_activation_apoz.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --act-stats=valid
#
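# Parameters:
# (Rows 22-23, module.layer2.1.conv1.weight (16, 32, 3, 3) and module.layer2.1.conv2.weight (32, 16, 3, 3), were lost
#  when this log was pasted; the 'Total sparsity' row still counts their 2 x 4608 = 9216 parameters.)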
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25444 |  0.01128 |    0.13307 |
# |  1 | module.layer1.0.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07351 |  0.00182 |    0.04119 |
# |  2 | module.layer1.0.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07510 | -0.00968 |    0.05190 |
# |  3 | module.layer1.1.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06982 |  0.00599 |    0.04476 |
# |  4 | module.layer1.1.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05886 | -0.01451 |    0.04284 |
# |  5 | module.layer1.2.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06894 | -0.00031 |    0.04735 |
# |  6 | module.layer1.2.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06561 | -0.00311 |    0.04952 |
# |  7 | module.layer1.3.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07374 | -0.00087 |    0.05137 |
# |  8 | module.layer1.3.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07346 | -0.00474 |    0.05348 |
# |  9 | module.layer1.4.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06855 |  0.00053 |    0.04867 |
# | 10 | module.layer1.4.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07078 | -0.01038 |    0.05366 |
# | 11 | module.layer1.5.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09372 | -0.00430 |    0.06283 |
# | 12 | module.layer1.5.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09056 | -0.00089 |    0.06517 |
# | 13 | module.layer1.6.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08050 | -0.00971 |    0.06157 |
# | 14 | module.layer1.6.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08000 | -0.00081 |    0.06004 |
# | 15 | module.layer1.7.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09966 | -0.01270 |    0.07424 |
# | 16 | module.layer1.7.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09293 |  0.00685 |    0.07128 |
# | 17 | module.layer1.8.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08764 | -0.01361 |    0.06730 |
# | 18 | module.layer1.8.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07053 |  0.00491 |    0.05341 |
# | 19 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09426 | -0.00345 |    0.07094 |
# | 20 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07798 | -0.00154 |    0.05783 |
# | 21 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16513 |  0.00688 |    0.11354 |
# | 24 | module.layer2.2.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05589 | -0.00558 |    0.04355 |
# | 25 | module.layer2.2.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04979 | -0.00491 |    0.03863 |
# | 26 | module.layer2.3.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05622 | -0.00471 |    0.04379 |
# | 27 | module.layer2.3.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04541 | -0.00271 |    0.03535 |
# | 28 | module.layer2.4.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05166 | -0.00597 |    0.03896 |
# | 29 | module.layer2.4.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04098 | -0.00381 |    0.03114 |
# | 30 | module.layer2.5.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04188 | -0.00373 |    0.03040 |
# | 31 | module.layer2.5.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03249 | -0.00190 |    0.02291 |
# | 32 | module.layer2.6.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04584 | -0.00553 |    0.03569 |
# | 33 | module.layer2.6.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03655 | -0.00216 |    0.02758 |
# | 34 | module.layer2.7.conv1.weight        | (16, 32, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05110 | -0.00700 |    0.03909 |
# | 35 | module.layer2.7.conv2.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03884 | -0.00129 |    0.02946 |
# | 36 | module.layer2.8.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03331 | -0.00269 |    0.02211 |
# | 37 | module.layer2.8.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02406 | -0.00014 |    0.01479 |
# | 38 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05957 | -0.00091 |    0.04658 |
# | 39 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05103 | -0.00016 |    0.03729 |
# | 40 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09200 |  0.00203 |    0.06440 |
# | 41 | module.layer3.1.conv1.weight        | (58, 64, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03758 | -0.00117 |    0.02728 |
# | 42 | module.layer3.1.conv2.weight        | (64, 58, 3, 3) |         33408 |          33408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03577 | -0.00397 |    0.02686 |
# | 43 | module.layer3.2.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03704 | -0.00146 |    0.02762 |
# | 44 | module.layer3.2.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03409 | -0.00464 |    0.02638 |
# | 45 | module.layer3.3.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03887 | -0.00274 |    0.03015 |
# | 46 | module.layer3.3.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03390 | -0.00448 |    0.02648 |
# | 47 | module.layer3.4.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04296 | -0.00361 |    0.03345 |
# | 48 | module.layer3.4.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03454 | -0.00255 |    0.02628 |
# | 49 | module.layer3.5.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04208 | -0.00441 |    0.03301 |
# | 50 | module.layer3.5.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03186 | -0.00319 |    0.02431 |
# | 51 | module.layer3.6.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03119 | -0.00262 |    0.02419 |
# | 52 | module.layer3.6.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02298 | -0.00015 |    0.01670 |
# | 53 | module.layer3.7.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02912 | -0.00265 |    0.02235 |
# | 54 | module.layer3.7.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02078 | -0.00010 |    0.01524 |
# | 55 | module.layer3.8.conv1.weight        | (45, 64, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03270 | -0.00269 |    0.02542 |
# | 56 | module.layer3.8.conv2.weight        | (64, 45, 3, 3) |         25920 |          25920 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02244 |  0.00045 |    0.01630 |
# | 57 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42577 | -0.00001 |    0.33523 |
# | 58 | Total sparsity:                     | -              |        634640 |         634640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=249)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 92.740    Top5: 99.720    Loss: 1.534
#
# ==> Best Top1: 92.760   On Epoch: 237
#
# Saving checkpoint to: logs/2018.10.16-013006/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 93.030    Top5: 99.650    Loss: 1.533
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.16-013006/2018.10.16-013006.log
#
# real    49m0.623s
# user    90m51.054s
# sys     8m36.745s

version: 1

# All pruners rank filters by the APoZ of their output activations and use the
# AGP variant: the filter sparsity of each listed weight is raised step by step
# from initial_sparsity to final_sparsity, on the epochs set by its policy.
# Only the conv1 layer of each residual block is pruned; conv2's input channels
# shrink to match (see the shapes in the table above).
pruners:
  # layer1 (16 filters per conv1): target 60% filter sparsity.
  filter_pruner_60:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.6
    group_type: Filters
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer1.3.conv1.weight
      - module.layer1.4.conv1.weight
      - module.layer1.5.conv1.weight
      - module.layer1.6.conv1.weight
      - module.layer1.7.conv1.weight
      - module.layer1.8.conv1.weight

  # layer2 (32 filters per conv1): target 50% filter sparsity.
  # Blocks 0, 5 and 8 are intentionally left unpruned.
  filter_pruner_50:
    # Alternative class (disabled): StructuredAutomatedGradualPruner
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.5
    group_type: Filters
    weights:
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer2.3.conv1.weight
      - module.layer2.4.conv1.weight
      - module.layer2.6.conv1.weight
      - module.layer2.7.conv1.weight

  # layer3, block 1 (64 filters): light pruning, target 10% filter sparsity.
  filter_pruner_10:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0
    final_sparsity: 0.1
    group_type: Filters
    weights:
      - module.layer3.1.conv1.weight

  # layer3 (64 filters per conv1): target 30% filter sparsity.
  # Blocks 0, 1 (handled above) and 4 are not in this group.
  filter_pruner_30:
    class: ActivationAPoZRankedFilterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.3
    group_type: Filters
    weights:
      - module.layer3.2.conv1.weight
      - module.layer3.3.conv1.weight
      - module.layer3.5.conv1.weight
      - module.layer3.6.conv1.weight
      - module.layer3.7.conv1.weight
      - module.layer3.8.conv1.weight


extensions:
  # Thinning step: removes the filters zeroed by the pruners and shrinks the
  # weight tensors accordingly (compare the reduced shapes in the table above).
  net_thinner:
    class: 'FilterRemover'
    thinning_func_str: 'remove_filters'
    arch: 'resnet56_cifar'
    dataset: 'cifar10'

lr_schedulers:
  # Exponential decay: the learning rate is multiplied by gamma at each step
  # of the policy that references this scheduler.
  exp_finetuning_lr:
    class: ExponentialLR
    gamma: 0.95


policies:
  # Pruning phase: every pruner is applied every 2nd epoch, from
  # starting_epoch 1 up to ending_epoch 20.
  - pruner: {instance_name: filter_pruner_60}
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  - pruner: {instance_name: filter_pruner_50}
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  - pruner: {instance_name: filter_pruner_30}
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  - pruner: {instance_name: filter_pruner_10}
    starting_epoch: 1
    ending_epoch: 20
    frequency: 2

  # Thinning: run the filter remover once, at epoch 20.
  - extension: {instance_name: net_thinner}
    epochs: [20]

  # Learning-rate decay: step the exponential scheduler every epoch,
  # between epochs 10 and 300.
  - lr_scheduler: {instance_name: exp_finetuning_lr}
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1
