# SSL (Structured Sparsity Learning): Filter regularization
#
# This is a not-so-successful attempt at filter regularization:
# Total MACs: 27,800,192 = 68% compute density.
# Test Top1 after training: 90.39
# Test Top1 after fine-tuning: 90.93
#
# To train:
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_filter-removal_training.yaml -j=1 --deterministic --name="filters"
#
# To fine-tune:
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml --reset-optimizer --resume-from=...
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.42230 | -0.00107 |    0.29227 |
# |  1 | module.layer1.0.conv1.weight        | (13, 16, 3, 3) |          1872 |           1872 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04139 | -0.00293 |    0.02218 |
# |  2 | module.layer1.0.conv2.weight        | (16, 13, 3, 3) |          1872 |           1872 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16320 | -0.00121 |    0.10359 |
# |  3 | module.layer1.1.conv1.weight        | (9, 16, 3, 3)  |          1296 |           1296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02499 | -0.00091 |    0.01594 |
# |  4 | module.layer1.1.conv2.weight        | (16, 9, 3, 3)  |          1296 |           1296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13183 | -0.01035 |    0.09682 |
# |  5 | module.layer1.2.conv1.weight        | (10, 16, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07616 | -0.00278 |    0.05246 |
# |  6 | module.layer1.2.conv2.weight        | (16, 10, 3, 3) |          1440 |           1440 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18164 | -0.01895 |    0.13244 |
# |  7 | module.layer2.0.conv1.weight        | (25, 16, 3, 3) |          3600 |           3600 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02687 |  0.00002 |    0.01887 |
# |  8 | module.layer2.0.conv2.weight        | (32, 25, 3, 3) |          7200 |           7200 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11768 | -0.01350 |    0.09049 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.30952 | -0.03258 |    0.21696 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09844 | -0.00413 |    0.07454 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11731 | -0.00819 |    0.09292 |
# | 12 | module.layer2.2.conv1.weight        | (4, 32, 3, 3)  |          1152 |           1152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02086 | -0.00164 |    0.01553 |
# | 13 | module.layer2.2.conv2.weight        | (32, 4, 3, 3)  |          1152 |           1152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08841 | -0.00491 |    0.06650 |
# | 14 | module.layer3.0.conv1.weight        | (48, 32, 3, 3) |         13824 |          13824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01713 | -0.00044 |    0.01255 |
# | 15 | module.layer3.0.conv2.weight        | (64, 48, 3, 3) |         27648 |          27648 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09969 | -0.00692 |    0.07733 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19055 | -0.01650 |    0.14967 |
# | 17 | module.layer3.1.conv1.weight        | (30, 64, 3, 3) |         17280 |          17280 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01936 |  0.00019 |    0.01468 |
# | 18 | module.layer3.1.conv2.weight        | (64, 30, 3, 3) |         17280 |          17280 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08263 | -0.01507 |    0.06434 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07410 | -0.00536 |    0.05833 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06848 | -0.00032 |    0.05342 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.66850 | -0.00003 |    0.54848 |
# | 22 | Total sparsity:                     | -              |        194144 |         194144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-09-22 15:44:14,320 - Total sparsity: 0.00
#
# 2018-09-22 15:44:14,321 - --- validate (epoch=179)-----------
# 2018-09-22 15:44:14,321 - 5000 samples (256 per mini-batch)
# 2018-09-22 15:44:15,800 - ==> Top1: 90.460    Top5: 99.720    Loss: 0.332
#
# 2018-09-22 15:44:15,802 - ==> Best Top1: 90.900   On Epoch: 148
#
# 2018-09-22 15:44:15,802 - Saving checkpoint to: logs/filters___2018.09.22-151047/filters_checkpoint.pth.tar
# 2018-09-22 15:44:15,818 - --- test ---------------------
# 2018-09-22 15:44:15,818 - 10000 samples (256 per mini-batch)
# 2018-09-22 15:44:17,459 - ==> Top1: 90.390    Top5: 99.750    Loss: 0.349

# Learning-rate scheduler definitions. Each entry is a named instance that an
# `lr_scheduler` policy refers to through `instance_name`.
lr_schedulers:
  training_lr:
    # NOTE(review): presumably PyTorch's StepLR, i.e. the learning rate is
    # multiplied by `gamma` every `step_size` epochs -- confirm against the
    # code that instantiates this class.
    class: StepLR
    step_size: 45
    gamma: 0.10

# Regularizer definitions. Each entry is a named instance that a `regularizer`
# policy refers to through `instance_name`.
regularizers:
  Filters_groups_regularizer:
    class: GroupLassoRegularizer
    # One entry per regularized weight tensor, written as [value, structure].
    # Only the `conv1` weights of each `layerX.Y` module are listed; every
    # entry uses the `Filters` structure.
    # NOTE(review): the first element is presumably the regularization
    # strength for that tensor -- confirm against GroupLassoRegularizer.
    # NOTE(review): the key is spelled `reg_regims`; presumably this is the
    # exact name the consumer expects -- verify before "correcting" it.
    reg_regims:
      module.layer1.0.conv1.weight: [0.0008, Filters]
      module.layer1.1.conv1.weight: [0.0008, Filters]
      module.layer1.2.conv1.weight: [0.0006, Filters]
      module.layer2.0.conv1.weight: [0.0008, Filters]
      module.layer2.1.conv1.weight: [0.0002, Filters]
      module.layer2.2.conv1.weight: [0.0008, Filters]
      module.layer3.0.conv1.weight: [0.0012, Filters]
      module.layer3.1.conv1.weight: [0.0010, Filters]
      module.layer3.2.conv1.weight: [0.0002, Filters]
    # NOTE(review): presumably selects how a group's magnitude is measured
    # when deciding whether to zero it (mean of absolute values) -- confirm.
    threshold_criteria: Mean_Abs

# Extension definitions. Each entry is a named instance that an `extension`
# policy refers to through `instance_name`.
extensions:
  # Thinning step configured for the same architecture/dataset pair used in
  # the training command at the top of this file.
  net_thinner:
    class: FilterRemover
    thinning_func_str: remove_filters
    arch: resnet20_cifar
    dataset: cifar10

# Policies bind the named instances defined earlier in this file to the
# epochs at which they are active.
policies:
  # Learning-rate schedule: active from epoch 45. The ending epoch (300) lies
  # beyond the 180-epoch training run shown in the header command.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 45
    ending_epoch: 300
    frequency: 1

  # After completing the regularization, we perform network thinning and exit.
  # Epoch 179 is the last epoch of the 180-epoch training run.
  - extension:
      instance_name: net_thinner
    epochs: [179]

  # Group-lasso regularization, applied at every epoch of the training run.
  - regularizer:
      instance_name: Filters_groups_regularizer
      args:
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1
