# This is a hybrid pruning schedule composed of three pruning techniques, all using AGP scheduling (formula below):
# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
# 2. Fine-grained pruning to reduce the parameter memory requirements of layers with large weight tensors.
# 3. Row pruning for the last linear (fully-connected) layer.
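#
# All pruners below follow the AGP schedule (Zhu & Gupta, 2017), which ramps a
# tensor's sparsity from initial_sparsity (s_i) to final_sparsity (s_f) over a
# pruning window [t_0, t_end]:
#
#     s_t = s_f + (s_i - s_f) * (1 - (t - t_0) / (t_end - t_0))^3
#
# The cubic term prunes aggressively early, while the network is most resilient,
# and tapers off as the sparsity approaches s_f.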
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Top1: 91.29
#     Total MACs: 30,433,920 (74.6% of the original compute)
#     Total sparsity: 56.62% (weights-only; matches the log below)
#     # of parameters: 96,960  (=35.8% of the baseline parameters ==> 64.2% sparsity)
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0
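#
# (Flag notes: --resume-from loads the dense pre-trained checkpoint,
# --reset-optimizer discards the checkpoint's optimizer state so the pruning
# run starts fresh at --lr=0.4, and --vs=0 disables the validation split.)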
#
#  Parameters:
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41066 | -0.01044 |    0.28684 |
#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15311 | -0.01381 |    0.10292 |
#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15199 |  0.00052 |    0.10768 |
#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13703 | -0.01426 |    0.10459 |
#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12956 | -0.00363 |    0.10109 |
#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18044 | -0.01347 |    0.13435 |
#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14809 | -0.00015 |    0.11282 |
#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16746 | -0.00541 |    0.13009 |
#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18162 | -0.00449 |    0.14069 |
#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34670 |  0.00145 |    0.23974 |
#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11932 | -0.00470 |    0.08793 |
#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09387 |  0.00694 |    0.06755 |
#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11968 | -0.00680 |    0.08725 |
#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08844 |  0.00324 |    0.06366 |
#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           4608 |    0.00000 |    0.00000 |  0.00000 |  1.26953 |  0.00000 |   50.00000 | 0.11985 | -0.00543 |    0.07131 |
#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.48828 |  0.00000 |   50.00000 | 0.08819 | -0.00512 |    0.05183 |
#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15923 | -0.01150 |    0.12339 |
#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.03125 |  0.00000 |   69.99783 | 0.07690 | -0.00368 |    0.03622 |
#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  8.78906 |  0.00000 |   69.99783 | 0.07113 | -0.00574 |    0.03346 |
#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 18.65234 |  9.37500 |   69.99783 | 0.06893 | -0.00443 |    0.03236 |
#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  9.37500 | 33.39844 |  0.00000 |   69.99783 | 0.04186 |  0.00078 |    0.01802 |
#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.54733 | -0.00001 |    0.30963 |
#  | 22 | Total sparsity:                     | -              |        223536 |          96960 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   56.62444 | 0.00000 |  0.00000 |    0.00000 |
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  Total sparsity: 56.62
#
#  --- validate (epoch=179)-----------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.290    Top5: 99.760    Loss: 0.344
#
#  ==> Best [Top1: 91.290   Top5: 99.760   Sparsity:56.62   NNZ-Params: 96960 on epoch: 179]
#  Saving checkpoint to: logs/2019.11.01-005544/checkpoint.pth.tar
#  --- test ---------------------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.290    Top5: 99.760    Loss: 0.344
#
#
#  real    31m53.439s
#  user    264m58.198s
#  sys     16m57.435s
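#
#  A minimal sketch (plain Python, not the Distiller API) of the sparsity
#  target AGP computes at each pruning step:
#
#      def agp_sparsity(s_i, s_f, epoch, t_0, t_end):
#          # Cubic ramp from s_i to s_f; see the formula at the top of this file.
#          progress = min(max((epoch - t_0) / (t_end - t_0), 0.0), 1.0)
#          return s_f + (s_i - s_f) * (1.0 - progress) ** 3
#
#      # e.g. low_pruner (s_i=0.10, s_f=0.50) halfway through epochs 0..30:
#      assert abs(agp_sparsity(0.10, 0.50, 15, 0, 30) - 0.45) < 1e-9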

version: 1

pruners:
  low_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    weights: [module.layer2.0.conv1.weight,
              module.layer2.0.conv2.weight,
              module.layer2.0.downsample.0.weight,
              module.layer2.1.conv2.weight,
              module.layer2.2.conv2.weight,  # to balance the BN
              module.layer2.1.conv1.weight,
              module.layer2.2.conv1.weight]
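
  # low_pruner ranks whole 3D filters by the L1 norm of their weights and
  # zeroes the lowest-ranked ones. A minimal sketch of that ranking (plain
  # PyTorch, not the Distiller API):
  #
  #     import torch
  #     def l1_filter_mask(weight, sparsity):  # weight: (n_filters, C, kH, kW)
  #         scores = weight.abs().sum(dim=(1, 2, 3))     # L1 norm per filter
  #         n_prune = int(sparsity * weight.size(0))
  #         mask = torch.ones_like(weight)
  #         mask[scores.argsort()[:n_prune]] = 0.0       # drop weakest filters
  #         return mask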

  fine_pruner1:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.50
    weights: [module.layer3.0.conv1.weight,  module.layer3.0.conv2.weight]

  fine_pruner2:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights: [module.layer3.1.conv1.weight,  module.layer3.1.conv2.weight,
              module.layer3.2.conv1.weight,  module.layer3.2.conv2.weight]
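
  # Both fine pruners perform element-wise magnitude pruning: individual
  # weights are zeroed, smallest magnitudes first, up to the scheduled
  # sparsity. A minimal sketch (plain PyTorch, not the Distiller API):
  #
  #     import torch
  #     def magnitude_mask(weight, sparsity):
  #         k = int(sparsity * weight.numel())           # elements to zero
  #         if k == 0:
  #             return torch.ones_like(weight)
  #         threshold = weight.abs().flatten().kthvalue(k).values
  #         return (weight.abs() > threshold).to(weight.dtype)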

  fc_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Rows
    weights: [module.fc.weight]
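
  # Note: module.fc.weight uses PyTorch's (out_features, in_features) = (10, 64)
  # layout. Pruning "Rows" here removes input features of the classifier rather
  # than output classes; zeroing 5 of the 10 class rows would cap Top1 near 50%,
  # whereas this run retains 91.29%.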

lr_schedulers:
  pruning_lr:
    class: StepLR
    step_size: 50
    gamma: 0.10
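
    # With the command line's --lr=0.4, StepLR gives:
    #     lr(epoch) = 0.4 * 0.1^floor(epoch / 50)
    # i.e. 0.4, then 0.04 at epoch 50, 0.004 at epoch 100, 0.0004 at epoch 150.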

extensions:
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10
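
  # FilterRemover ("network thinning") physically removes the filters that the
  # structure pruner zeroed, and shrinks every dependent tensor: the following
  # BatchNorm parameters and the next layer's input channels. This is why the
  # layer2 convolutions appear as (16, 16, 3, 3) in the table above instead of
  # resnet20_cifar's usual 32 channels.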


policies:
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2

  - pruner:
      instance_name: fine_pruner1
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  - pruner:
      instance_name: fine_pruner2
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  - pruner:
      instance_name: fc_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2
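
  # Schedule summary: filter pruning ramps during epochs 0-30; element-wise and
  # row pruning ramp during epochs 30-50; the network is physically thinned once
  # at epoch 32 (below); fine-tuning then continues until epoch 180 with the LR
  # decaying every 50 epochs.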

  # The thinner is disabled here until the structure pruner is done, because
  # thinning mid-schedule interacts with the sparsity goals of the
  # L1RankedStructureParameterPruner_AGP.  This can be fixed rather easily.
  # - extension:
  #     instance_name: net_thinner
  #   starting_epoch: 0
  #   ending_epoch: 20
  #   frequency: 2

  # After the filter pruning completes, we perform network thinning once (at
  # epoch 32) and continue fine-tuning the smaller network.
  - extension:
      instance_name: net_thinner
    epochs: [32]

  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
