# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
# 2. Fine-grained pruning to reduce the parameter memory requirements of layers with large weight tensors.
# 3. Row pruning for the last linear (fully-connected) layer.
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Top1: 91.34
#     Total MACs: 30,655,104
#     Total sparsity: 46.3%
#     # of parameters: 120,000  (=44.3% of the baseline parameters, i.e. a 55.7% reduction)
#
# time python3 compress_classifier.py --arch resnet20_cifar  $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0
#
#  Parameters:
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41726 | -0.00601 |    0.29649 |
#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15451 | -0.01086 |    0.10477 |
#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15074 | -0.00062 |    0.10780 |
#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13604 | -0.01868 |    0.10372 |
#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12923 | -0.00463 |    0.09968 |
#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17818 | -0.01222 |    0.13184 |
#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14475 | -0.00069 |    0.11089 |
#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16991 |  0.00091 |    0.12894 |
#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18059 |  0.00199 |    0.14176 |
#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34145 | -0.03631 |    0.25094 |
#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13213 | -0.00809 |    0.10198 |
#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10230 |  0.00805 |    0.07883 |
#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11462 | -0.00682 |    0.08532 |
#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08012 |  0.00611 |    0.05776 |
#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13316 | -0.00256 |    0.10497 |
#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09755 | -0.00598 |    0.07722 |
#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16702 |  0.00251 |    0.12968 |
#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 10.25391 |  4.68750 |   69.99783 | 0.07554 | -0.00373 |    0.03568 |
#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  4.68750 | 11.49902 |  0.00000 |   69.99783 | 0.06968 | -0.00573 |    0.03275 |
#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 15.57617 |  4.68750 |   69.99783 | 0.06895 | -0.00504 |    0.03245 |
#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  4.68750 | 32.08008 |  0.00000 |   69.99783 | 0.04180 |  0.00053 |    0.01793 |
#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.55055 | -0.00001 |    0.31038 |
#  | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  Total sparsity: 46.32
#
#  --- validate (epoch=179)-----------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.110    Top5: 99.700    Loss: 0.379
#
#  ==> Best [Top1: 91.340   Top5: 99.670   Sparsity:46.32   NNZ-Params: 120000 on epoch: 127]
#  Saving checkpoint to: logs/2019.10.31-204215/checkpoint.pth.tar
#  --- test ---------------------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.110    Top5: 99.700    Loss: 0.363
#
#
#  real    31m38.549s
#  user    293m36.410s
#  sys     18m6.041s

# Schedule version marker.
# NOTE(review): presumably the schedule-file format version read by the loader - confirm.
version: 1

pruners:
  # Filter-wise structured pruner (AGP schedule): 10% -> 50% sparsity over the
  # layer2 convolutions, including the downsample convolution of block 0.
  low_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    weights:
      - module.layer2.0.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.0.downsample.0.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight

  # Fine-grained (element-wise) AGP pruner: 5% -> 70% sparsity on the large
  # 64x64x3x3 weight tensors of layer3 blocks 1 and 2.
  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights:
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv1.weight
      - module.layer3.2.conv2.weight

  # Row-wise structured pruner (AGP schedule): 5% -> 50% sparsity on the final
  # fully-connected layer.
  fc_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Rows
    weights:
      - module.fc.weight


# Learning-rate scheduler definitions; instances are referenced by name from
# the `policies` section.
lr_schedulers:
  # Step decay with step_size 50 and gamma 0.10.
  # NOTE(review): presumably maps to torch.optim.lr_scheduler.StepLR, i.e. the
  # LR is multiplied by `gamma` every `step_size` epochs - confirm.
  pruning_lr:
    class: StepLR
    step_size: 50
    gamma: 0.10

extensions:
  # Network-thinning step: applies `remove_filters` for the resnet20_cifar
  # architecture on the cifar10 dataset. Scheduled from the `policies` section.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10


policies:
  # Phase 1: filter pruning of layer2, epochs 0 to 30, frequency 2.
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2

  # Once filter pruning is complete, thin the network and keep fine-tuning.
  # Distiller resolves ambiguous scheduling order by following the order of
  # declaration. Epoch 30 is both the end of low_pruner and the start of the
  # two pruners below, and thinning must happen at the beginning of epoch 30,
  # so this policy has to be declared here and not further down the file.
  - extension:
      instance_name: net_thinner
    epochs: [30]

  # Phase 2a: fine-grained pruning of layer3, epochs 30 to 50, frequency 2.
  - pruner:
      instance_name: fine_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  # Phase 2b: row pruning of the fully-connected layer, same window as 2a.
  - pruner:
      instance_name: fc_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  # Learning-rate schedule, active from epoch 0 with frequency 1.
  # NOTE(review): ending_epoch (400) is beyond the --epochs=180 shown in the
  # header command, so this presumably covers the whole run - confirm.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
