# This is a hybrid pruning schedule composed of several pruning techniques, all using AGP scheduling:
# 1. Filter pruning (and thinning) to reduce compute and activation sizes of some layers.
# 2. Fine grained pruning to reduce the parameter memory requirements of layers with large weights tensors.
# 3. Row pruning for the last linear (fully-connected) layer.
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Top1: 90.89
#     Total MACs: 24,723,776 (=1.65x fewer MACs than the baseline)
#     Total sparsity: 39.66
#     # of parameters: 78,776  (=29.1% of the baseline parameters)
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0 --gpu=0
#
#  Parameters:
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43641 | -0.01077 |    0.30848 |
#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16754 | -0.01429 |    0.11853 |
#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16230 |  0.00232 |    0.12090 |
#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14305 | -0.01696 |    0.10940 |
#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13429 | -0.00717 |    0.10360 |
#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19055 | -0.01061 |    0.14304 |
#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15439 |  0.00421 |    0.11755 |
#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18164 | -0.00175 |    0.14038 |
#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.20045 | -0.01430 |    0.15603 |
#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.33919 | -0.02323 |    0.25281 |
#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13992 | -0.00223 |    0.10972 |
#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11157 |  0.00652 |    0.08719 |
#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14406 | -0.00417 |    0.10841 |
#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10744 |  0.00565 |    0.08125 |
#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13871 | -0.00321 |    0.10984 |
#  | 15 | module.layer3.0.conv2.weight        | (32, 64, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12541 | -0.00566 |    0.09928 |
#  | 16 | module.layer3.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19916 | -0.00610 |    0.15371 |
#  | 17 | module.layer3.1.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 |  7.61719 |  1.56250 |   69.99783 | 0.10042 | -0.00397 |    0.04862 |
#  | 18 | module.layer3.1.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  1.56250 |  9.71680 |  0.00000 |   69.99783 | 0.09041 | -0.00711 |    0.04355 |
#  | 19 | module.layer3.2.conv1.weight        | (64, 32, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  0.00000 | 12.30469 |  3.12500 |   69.99783 | 0.09180 | -0.00314 |    0.04391 |
#  | 20 | module.layer3.2.conv2.weight        | (32, 64, 3, 3) |         18432 |           5530 |    0.00000 |    0.00000 |  3.12500 | 29.29688 |  0.00000 |   69.99783 | 0.06059 |  0.00070 |    0.02763 |
#  | 21 | module.fc.weight                    | (10, 32)       |           320 |            160 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.73275 | -0.00001 |    0.41788 |
#  | 22 | Total sparsity:                     | -              |        130544 |          78776 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   39.65560 | 0.00000 |  0.00000 |    0.00000 |
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  Total sparsity: 39.66
#
#  --- validate (epoch=179)-----------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 90.660    Top5: 99.700    Loss: 0.343
#
#  ==> Best [Top1: 90.890   Top5: 99.750   Sparsity:39.66   NNZ-Params: 78776 on epoch: 107]
#  Saving checkpoint to: logs/2019.11.01-013919/checkpoint.pth.tar
#  --- test ---------------------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 90.660    Top5: 99.700    Loss: 0.358
#
#
#  real    30m44.680s
#  user    290m49.048s
#  sys     17m17.928s

# Version tag of this schedule file (presumably the schedule-format version the
# loader expects -- TODO confirm against the consuming tool).
version: 1

# Pruner instances. Each one is activated by an entry of the same name in the
# `policies` section; sparsity values are fractions (0.50 == 50%).
pruners:
  # Filter (structure) pruning of the layer2 convolutions, 0.10 -> 0.50 sparsity.
  # NOTE(review): unlike low_pruner_2, no group_dependency is set here.
  low_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    weights:
      - module.layer2.0.conv1.weight
      - module.layer2.0.conv2.weight
      - module.layer2.0.downsample.0.weight
      - module.layer2.1.conv2.weight
      - module.layer2.2.conv2.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight

  # Filter pruning of the layer3 block outputs, 0.10 -> 0.50 sparsity.
  # group_dependency: Leader presumably makes every listed tensor follow the
  # filter selection of the first one -- TODO confirm in the pruner documentation.
  low_pruner_2:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    group_dependency: Leader
    weights:
      - module.layer3.0.conv2.weight
      - module.layer3.0.downsample.0.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv2.weight

  # Fine-grained pruning of the large layer3 weight tensors, 0.05 -> 0.70
  # sparsity. layer3.1.conv2 and layer3.2.conv2 are also listed in low_pruner_2.
  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights:
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer3.2.conv1.weight
      - module.layer3.2.conv2.weight

  # Row pruning of the final fully-connected layer, 0.05 -> 0.50 sparsity.
  fc_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Rows
    weights:
      - module.fc.weight


# Learning-rate scheduler instances; `pruning_lr` is activated by the
# lr_scheduler entry in the `policies` section.
lr_schedulers:
  pruning_lr:
    # Presumably torch.optim.lr_scheduler.StepLR, i.e. the learning rate is
    # multiplied by `gamma` every `step_size` epochs -- TODO confirm.
    class: StepLR
    step_size: 50
    gamma: 0.10

# Extension instances. `net_thinner` is referenced from the `policies` section.
# Presumably it physically removes the filters zeroed by the filter pruners
# (see "thinning" in the file header) -- TODO confirm.
extensions:
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10


# Schedule: binds each pruner / extension / LR-scheduler instance declared above
# to the epochs in which it is invoked.
policies:
  # Filter pruning of the layer2 tensors: starting_epoch 0, ending_epoch 30.
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2

  # Filter pruning of the layer3 tensors, on the same epoch window as low_pruner.
  - pruner:
      instance_name: low_pruner_2
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2

  # Fine-grained pruning starts where the filter pruners' window ends.
  - pruner:
      instance_name: fine_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  # Row pruning of the fully-connected layer, same window as fine_pruner.
  - pruner:
      instance_name: fc_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  # The periodic thinner below is currently disabled until the structure pruner
  # is done, because it interacts with the sparsity goals of the
  # L1RankedStructureParameterPruner_AGP.
  # This can be fixed rather easily.
  # - extension:
  #     instance_name: net_thinner
  #   starting_epoch: 0
  #   ending_epoch: 20
  #   frequency: 2

  # After the filter pruners' window (ending_epoch 30) we perform network
  # thinning once, at epoch 32, and continue fine-tuning. Note that fine_pruner
  # and fc_pruner are still scheduled (epochs 30 to 50) at that point.
  - extension:
      instance_name: net_thinner
    epochs: [32]

  # Learning-rate scheduler, scheduled with frequency 1 from epoch 0.
  # ending_epoch (400) lies beyond the --epochs=180 used in the command line
  # recorded in the file header.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
