# Random drop-filter: we randomly choose the percentage of filters to prune (the level), and then use
# L1-norm ranking to choose which filters to prune.
#
#
# time python3 compress_classifier.py --arch=resnet20_cifar ../../../data.cifar --lr=0.3 --epochs=180 --batch=256 --compress=../drop_filter/resnet20_cifar_randomlevel_training.yaml --vs=0 -p=50 --gpus=0
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.40816 | -0.00610 |    0.26546 |
# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15262 | -0.00699 |    0.10400 |
# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15914 | -0.01044 |    0.11828 |
# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13560 | -0.00450 |    0.09826 |
# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13313 | -0.00876 |    0.10116 |
# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17824 | -0.00447 |    0.13122 |
# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14746 | -0.00306 |    0.11315 |
# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13769 | -0.01100 |    0.10709 |
# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11786 | -0.00354 |    0.09118 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.25319 |  0.00040 |    0.19276 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10268 | -0.00995 |    0.07987 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09043 | -0.00485 |    0.07108 |
# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09899 | -0.01216 |    0.07835 |
# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08196 | -0.00298 |    0.06411 |
# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10177 | -0.00970 |    0.08108 |
# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09456 | -0.00433 |    0.07474 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14403 | -0.01702 |    0.11294 |
# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09133 | -0.00750 |    0.07241 |
# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07808 | -0.00790 |    0.06185 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07204 | -0.00507 |    0.05624 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04531 | -0.00329 |    0.03415 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.52613 | -0.00002 |    0.40702 |
# | 22 | Total sparsity:                     | -              |        270896 |         270896 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=179)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.660    Top5: 99.690    Loss: 0.374
#
# ==> Best [Top1: 91.800   Top5: 99.690   Sparsity:0.00   Params: 270896 on epoch: 171]
# Saving checkpoint to: logs/2019.03.24-162956/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.660    Top5: 99.690    Loss: 0.366
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2019.03.24-162956/2019.03.24-162956.log
#
# real    31m57.174s
# user    98m48.700s
# sys     12m8.557s


# Learning-rate schedulers. Each entry is a named instance that the `policies`
# section refers to through `instance_name`.
lr_schedulers:
  training_lr:
    # Presumably resolved to PyTorch's torch.optim.lr_scheduler.StepLR, which
    # multiplies the learning rate by `gamma` once every `step_size` epochs.
    class: StepLR
    step_size: 45
    gamma: 0.1

# Pruners. Each entry is a named instance that the `policies` section refers
# to through `instance_name`.
pruners:
  random_filter_pruner:
    class: RandomLevelStructureParameterPruner
    # Presumably the [low, high] bounds from which the random pruning level
    # (fraction of filters) is drawn - confirm against the pruner class.
    sparsity_range: [0.1, 0.2]
    group_type: Filters
    # Only the first convolution (conv1) of every residual block in
    # layer1..layer3 is listed; the stem conv, the conv2 layers, the
    # downsample convs and the fc layer are not pruned by this instance.
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.1.conv1.weight
      - module.layer1.2.conv1.weight
      - module.layer2.0.conv1.weight
      - module.layer2.1.conv1.weight
      - module.layer2.2.conv1.weight
      - module.layer3.0.conv1.weight
      - module.layer3.1.conv1.weight
      - module.layer3.2.conv1.weight

# Policies bind the named instances declared in `lr_schedulers` and `pruners`
# to an epoch window (`starting_epoch`..`ending_epoch`) and a `frequency`.
policies:
  # Learning-rate decay becomes active at epoch 30.
  # NOTE(review): `ending_epoch` (200) is larger than the 180 epochs used by
  # the command line documented in the file header, so in that run the
  # schedule stays active until training ends.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 30
    ending_epoch: 200
    frequency: 1

  # Random filter pruning is active from the first epoch; `ending_epoch` (300)
  # also exceeds the documented 180-epoch run.
  - pruner:
      instance_name: random_filter_pruner
      # Extra arguments handed to the pruning policy. Their exact semantics are
      # defined by the consuming framework; the notes below are inferred from
      # the argument names - TODO confirm.
      args:
        # Presumably: re-draw the masks every 16 mini-batches.
        mini_batch_pruning_frequency: 16
        # Presumably: masks are dropped after each mini-batch instead of being
        # kept permanently. The header's final "Total sparsity: 0.00" is
        # consistent with the trained weights ending up dense.
        discard_masks_at_minibatch_end: true
        mask_on_forward_only: true
        use_double_copies: true
        mask_gradients: true
    starting_epoch: 0
    ending_epoch: 300
    frequency: 1
