# This is the same as examples/agp-pruning/resnet20_filters.schedule_agp.yaml, but with BN-folding for ranking
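#
# "BN-folding for ranking" means that, when the structured pruners rank filters
# by L1-norm, each convolution's weights are first folded with the parameters of
# its downstream BatchNorm layer, so filter importance reflects the effective
# post-BN weights.  A sketch of the standard folding (assuming the usual BN
# formulation), for output filter k:
#
#     W_folded[k] = W[k] * gamma[k] / sqrt(running_var[k] + eps)
#
# Ranking then uses ||W_folded[k]||_1 instead of ||W[k]||_1.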
#
# Baseline results:
#     Top1: 91.780    Top5: 99.710    Loss: 0.376
#     Total MACs: 40,813,184
#     # of parameters: 270,896
#
# Results:
#     Top1: 91.34
#     Total MACs: 30,655,104
#     Total sparsity: 46.32%
#     # of parameters: 120,000  (=44.3% of the baseline's 270,896, a 55.7% reduction)
#
# time python3 compress_classifier.py --arch resnet20_cifar  $CIFAR10_PATH -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.bn_fold.schedule_agp.yaml  --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --vs=0 --reset-optimizer --gpu=0
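#
#  (Here --resume-from loads the pre-trained dense checkpoint, --reset-optimizer
#  discards that checkpoint's optimizer state so training restarts at --lr=0.4,
#  and --vs=0 sets aside no validation split from the training set.)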
#
#  Parameters:
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
#  |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
#  |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.41630 | -0.00526 |    0.29222 |
#  |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15170 | -0.01458 |    0.10325 |
#  |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14991 |  0.00199 |    0.10401 |
#  |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13939 | -0.01859 |    0.10473 |
#  |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13140 | -0.00861 |    0.10033 |
#  |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17184 | -0.00677 |    0.11853 |
#  |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13944 | -0.00140 |    0.09709 |
#  |  7 | module.layer2.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17399 | -0.00476 |    0.13297 |
#  |  8 | module.layer2.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18404 |  0.00603 |    0.14192 |
#  |  9 | module.layer2.0.downsample.0.weight | (16, 16, 1, 1) |           256 |            256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34235 | -0.01751 |    0.25170 |
#  | 10 | module.layer2.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09837 | -0.00903 |    0.07080 |
#  | 11 | module.layer2.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08172 | -0.00530 |    0.05841 |
#  | 12 | module.layer2.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12200 | -0.00766 |    0.09124 |
#  | 13 | module.layer2.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08811 |  0.00425 |    0.06511 |
#  | 14 | module.layer3.0.conv1.weight        | (64, 16, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13276 | -0.00406 |    0.10314 |
#  | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09774 | -0.00467 |    0.07597 |
#  | 16 | module.layer3.0.downsample.0.weight | (64, 16, 1, 1) |          1024 |           1024 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16717 | -0.00613 |    0.13268 |
#  | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 |  7.76367 |  1.56250 |   69.99783 | 0.07646 | -0.00394 |    0.03609 |
#  | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  1.56250 | 10.59570 |  0.00000 |   69.99783 | 0.07015 | -0.00542 |    0.03294 |
#  | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  0.00000 | 13.37891 |  3.12500 |   69.99783 | 0.07038 | -0.00402 |    0.03317 |
#  | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          11060 |    0.00000 |    0.00000 |  3.12500 | 30.32227 |  0.00000 |   69.99783 | 0.04269 |  0.00014 |    0.01844 |
#  | 21 | module.fc.weight                    | (10, 64)       |           640 |            320 |    0.00000 |   50.00000 |  0.00000 |  0.00000 |  0.00000 |   50.00000 | 0.55338 | -0.00001 |    0.31654 |
#  | 22 | Total sparsity:                     | -              |        223536 |         120000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   46.31737 | 0.00000 |  0.00000 |    0.00000 |
#  +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#  Total sparsity: 46.32
#
#  --- validate (epoch=179)-----------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.120    Top5: 99.660    Loss: 0.374
#
#  ==> Best [Top1: 91.340   Top5: 99.650   Sparsity:46.32   NNZ-Params: 120000 on epoch: 109]
#  Saving checkpoint to: logs/2019.10.31-200150/checkpoint.pth.tar
#  --- test ---------------------
#  10000 samples (256 per mini-batch)
#  ==> Top1: 91.120    Top5: 99.660    Loss: 0.364
#
#
#  real    33m39.771s
#  user    289m4.681s
#  sys     17m43.188s

version: 1

pruners:
  low_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    weights: [module.layer2.0.conv1.weight, module.layer2.0.conv2.weight,
              module.layer2.0.downsample.0.weight,
              module.layer2.1.conv2.weight, module.layer2.2.conv2.weight,
              module.layer2.1.conv1.weight, module.layer2.2.conv1.weight]
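
# The *_AGP pruners follow the automated gradual pruning (AGP) schedule of
# Zhu & Gupta (2017): target sparsity ramps from initial_sparsity to
# final_sparsity along a cubic curve over the epoch range set by each pruner's
# policy (declared further down), re-applied every `frequency` epochs.
# A minimal Python sketch, using low_pruner's numbers (0.10 -> 0.50 over
# epochs 0-30):
#
#   def agp_sparsity(epoch, s_i=0.10, s_f=0.50, e_start=0, e_end=30):
#       t = min(max((epoch - e_start) / (e_end - e_start), 0.0), 1.0)
#       return s_f + (s_i - s_f) * (1.0 - t) ** 3
#
#   # agp_sparsity(0) = 0.10,  agp_sparsity(15) = 0.45,  agp_sparsity(30) = 0.50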

  fine_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.70
    weights: [module.layer3.1.conv1.weight,  module.layer3.1.conv2.weight,
              module.layer3.2.conv1.weight,  module.layer3.2.conv2.weight]
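
# AutomatedGradualPruner is fine-grained (element-wise): it zeros individual
# weights rather than whole structures, on the same cubic AGP schedule.  Its
# final_sparsity of 0.70 is what shows up as the ~70% "Fine (%)" sparsity of
# the layer3.1/layer3.2 weights in the table above.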

  fc_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.05
    final_sparsity: 0.50
    group_type: Rows
    weights: [module.fc.weight]
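
# group_type: Rows applies the structured L1 ranking to whole rows of the 2D
# FC weight.  At final_sparsity of 0.50, half of module.fc.weight is zeroed
# row-wise (640 -> 320 NNZ and 50% "Rows (%)" in the table above).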


lr_schedulers:
  pruning_lr:
    class: StepLR
    step_size: 50
    gamma: 0.10
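
# With the command line's --lr=0.4, StepLR decays the learning rate 10x every
# 50 epochs: 0.4 (epochs 0-49), 0.04 (50-99), 0.004 (100-149), 0.0004 (150-179).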

extensions:
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10
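
# FilterRemover ("network thinning") physically removes the filters that the
# structured pruners zeroed, shrinking both the pruned tensors and the layers
# that consume their outputs.  This is what drops the dense parameter count
# from the baseline's 270,896 to the 223,536 in the table, and the MACs from
# ~40.8M to ~30.7M.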


policies:
  - pruner:
      instance_name: low_pruner
      args:
        fold_batchnorm: True
    starting_epoch: 0
    ending_epoch: 30
    frequency: 2

# After completing the pruning, we perform network thinning and continue fine-tuning.
# When there is ambiguity in the scheduling order of policies, Distiller follows the
# order of declaration.  Because epoch 30 is the end of one pruner, and the beginning
# of two others, and because we want the thinning to happen at the beginning of
# epoch 30, it is important to declare the thinning policy here and not lower in the
# file.
  - extension:
      instance_name: net_thinner
    epochs: [30]

  - pruner:
      instance_name: fine_pruner
      args:
        fold_batchnorm: True
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  - pruner:
      instance_name: fc_pruner
    starting_epoch: 30
    ending_epoch: 50
    frequency: 2

  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1
