# This is a hybrid pruning schedule composed of two pruning techniques, both using AGP scheduling:
# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
# 2. Fine-grained pruning to reduce the parameter memory requirements of layers with large weights tensors.
# Note: the last linear (fully-connected) layer is not pruned by this schedule; no pruner below targets module.fc.weight.
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Top1: 91.630   Top5: 99.670
#     Total MACs: 30,638,720
#     Total sparsity: 41.84
#     # of parameters: 143,488 (=53% of the baseline parameters)
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.1 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_2.yaml -j=1 --deterministic --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --gpus=0 --vs=0
#
#  Parameters:
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.38465 | -0.00533 |    0.27349 |
#  |  1 | module.layer1.0.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17334 | -0.01720 |    0.12535 |
#  |  2 | module.layer1.0.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17280 |  0.00148 |    0.12660 |
#  |  3 | module.layer1.1.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14518 | -0.02108 |    0.11044 |
#  |  4 | module.layer1.1.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13157 | -0.00240 |    0.09998 |
#  |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18724 | -0.00470 |    0.13594 |
#  |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15303 | -0.00564 |    0.11591 |
#  |  7 | module.layer2.0.conv1.weight        | (20, 16, 3, 3) |          2880 |           2880 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15062 | -0.00379 |    0.11690 |
#  |  8 | module.layer2.0.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12943 | -0.00739 |    0.10150 |
#  |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25227 | -0.01490 |    0.17715 |
#  | 10 | module.layer2.1.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11074 | -0.00783 |    0.08721 |
#  | 11 | module.layer2.1.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09753 | -0.00582 |    0.07681 |
#  | 12 | module.layer2.2.conv1.weight        | (20, 32, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11501 | -0.01363 |    0.09121 |
#  | 13 | module.layer2.2.conv2.weight        | (32, 20, 3, 3) |          5760 |           5760 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09145 |  0.00280 |    0.07167 |
#  | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09772 | -0.00674 |    0.07769 |
#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09316 | -0.00339 |    0.07396 |
#  | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12438 | -0.00958 |    0.09868 |
#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  6.71387 |  0.00000 |   69.99783 | 0.07404 | -0.00405 |    0.03694 |
#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.25098 |  0.00000 |   69.99783 | 0.06739 | -0.00494 |    0.03356 |
#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.37598 |  0.00000 |   69.99783 | 0.06739 | -0.00414 |    0.03368 |
#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 28.49121 |  0.00000 |   69.99783 | 0.03788 |  0.00048 |    0.01900 |
#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.54585 | -0.00002 |    0.46076 |
#  | 22 | Total sparsity:                     | -              |        246704 |         143488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   41.83799 | 0.00000 |  0.00000 |    0.00000 |
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  Total sparsity: 41.84
#
#  --- validate (epoch=179)-----------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.430    Top5: 99.640    Loss: 0.365
#
#  ==> Best [Top1: 91.630   Top5: 99.670   Sparsity:41.84   NNZ-Params: 143488 on epoch: 74]
#  Saving checkpoint to: logs/2019.10.31-235045/checkpoint.pth.tar
#  --- test ---------------------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.430    Top5: 99.640    Loss: 0.379
#
#
#  Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller_remote/examples/classifier_compression/logs/2019.10.31-235045/2019.10.31-235045.log
#
#  real    52m57.688s
#  user    304m20.353s
#  sys     10m56.498s

# Schedule format version (presumably checked by the schedule loader — confirm).
version: 1
pruners:
  # Structured pruning of whole filters (group_type: Filters). The target
  # sparsity goes from 10% to 40% on the first convolution of each block in
  # layer1 and layer2.
  low_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.40
    group_type: Filters
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight

  # Fine-grained pruning of the four largest weight tensors (the 64x64x3x3
  # convolutions of layer3.1 and layer3.2). The target sparsity goes from 5%
  # to 70%.
  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights:
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv1.weight
      - module.layer3.2.conv2.weight

lr_schedulers:
  # Step-decay learning-rate schedule: step_size 50, gamma 0.10 (presumably
  # torch.optim.lr_scheduler.StepLR, i.e. the LR is multiplied by gamma every
  # step_size epochs — confirm against the schedule loader).
  pruning_lr:
    class: StepLR
    step_size: 50
    gamma: 0.10


extensions:
  # Network-thinning extension (FilterRemover with remove_filters), configured
  # for the resnet20_cifar architecture on the cifar10 dataset.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10

policies:
  # Filter pruning: starting_epoch 0, ending_epoch 20, frequency 2.
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 20
    frequency: 2

  # Fine-grained pruning: starting_epoch 20, ending_epoch 40, frequency 2.
  - pruner:
      instance_name: fine_pruner
    starting_epoch: 20
    ending_epoch: 40
    frequency: 2

  # Network thinning is scheduled once, at epoch 22: after the filter-pruning
  # policy's ending_epoch (20), while the fine-grained pruning policy (20-40)
  # is still in progress. Fine-tuning continues afterwards.
  - extension:
      instance_name: net_thinner
    epochs: [22]

  # Learning-rate schedule with frequency 1. Its ending_epoch (400) exceeds
  # the 180 epochs used by the example command in the file header.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
