# SSL: Channel regularization
# We compressed the compute from 313M MACs to 85M MACs (x3.7), and the parameters from 14.7M to 725K (x20) using SSL
# with channel-wise regularization, with a drop of 0.19% Top1 accuracy and without much effort.
#
# time python3 compress_classifier.py --arch vgg16_cifar  ../../../data.cifar10 -p=50 --lr=0.05 --epochs=180 --compress=../ssl/vgg16_cifar_ssl_channels_training.yaml -j=1 --deterministic
#
# The results below are from the SSL training session; you can follow up with some fine-tuning:
# time python3 compress_classifier.py --arch vgg16_cifar  ../../../data.cifar10 --resume-from=checkpoint.vgg16_cifar.pth.tar  --reset-optimizer --lr=0.01 --epochs=20
# ==> Top1: 91.010    Top5: 99.480    Loss: 0.513
#
# Parameters:
# +----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |          | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0.00000 | features.module.0.weight  | (31, 3, 3, 3)    |           837 |            837 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.32376 | -0.00329 |    0.23517 |
# |  1.00000 | features.module.2.weight  | (47, 31, 3, 3)   |         13113 |          13113 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07184 | -0.00374 |    0.04210 |
# |  2.00000 | features.module.5.weight  | (98, 47, 3, 3)   |         41454 |          41454 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04835 | -0.00384 |    0.03511 |
# |  3.00000 | features.module.7.weight  | (117, 98, 3, 3)  |        103194 |         103194 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03096 | -0.00500 |    0.02305 |
# |  4.00000 | features.module.10.weight | (193, 117, 3, 3) |        203229 |         203229 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02948 | -0.00345 |    0.02259 |
# |  5.00000 | features.module.12.weight | (164, 193, 3, 3) |        284868 |         284868 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01766 | -0.00233 |    0.01313 |
# |  6.00000 | features.module.14.weight | (24, 164, 3, 3)  |         35424 |          35424 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01755 | -0.00082 |    0.01235 |
# |  7.00000 | features.module.17.weight | (15, 24, 3, 3)   |          3240 |           3240 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04479 |  0.00043 |    0.03239 |
# |  8.00000 | features.module.19.weight | (9, 15, 3, 3)    |          1215 |           1215 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08220 |  0.00163 |    0.06065 |
# |  9.00000 | features.module.21.weight | (7, 9, 3, 3)     |           567 |            567 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12944 |  0.00961 |    0.09122 |
# | 10.00000 | features.module.24.weight | (5, 7, 3, 3)     |           315 |            315 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17881 |  0.02638 |    0.12776 |
# | 11.00000 | features.module.26.weight | (7, 5, 3, 3)     |           315 |            315 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18656 |  0.03477 |    0.12432 |
# | 12.00000 | features.module.28.weight | (512, 7, 3, 3)   |         32256 |          32256 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02310 |  0.00009 |    0.01329 |
# | 13.00000 | classifier.weight         | (10, 512)        |          5120 |           5120 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10157 | -0.00002 |    0.07181 |
# | 14.00000 | Total sparsity:           | -                |        725147 |         725147 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----------+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
#
# Total sparsity: 0.00
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# Test: [   10/   39]    Loss 0.454324    Top1 90.664062    Top5 99.609375
# Test: [   20/   39]    Loss 0.450643    Top1 90.722656    Top5 99.511719
# Test: [   30/   39]    Loss 0.441285    Top1 90.807292    Top5 99.557292
# Test: [   40/   39]    Loss 0.458055    Top1 90.740000    Top5 99.580000
# ==> Top1: 90.740    Top5: 99.580    Loss: 0.458

# Learning-rate scheduler instances. `training_lr` is the name the
# `lr_scheduler` policy further down refers to via `instance_name`.
lr_schedulers:
  training_lr:
    # NOTE(review): presumably resolves to torch.optim.lr_scheduler.StepLR,
    # i.e. the LR is multiplied by `gamma` every `step_size` epochs - confirm
    # against the framework version in use.
    class: StepLR
    step_size: 45
    gamma: 0.10

# Regularizer instances. `Channels_groups_regularizer` is the name the
# `regularizer` policy further down refers to via `instance_name`.
regularizers:
  Channels_groups_regularizer:
    class: GroupLassoRegularizer
    # One entry per `features.module.*` weight tensor listed in the results
    # table in the file header; `classifier.weight` is not regularized.
    # NOTE(review): each value looks like [regularization strength, group
    # structure] - confirm against the regularizer's documentation.
    # NOTE(review): the key spelling `reg_regims` is presumably what the
    # consumer expects; do not "correct" it without checking the loader.
    reg_regims:
      features.module.0.weight: [0.0008, Channels]
      features.module.2.weight: [0.0008, Channels]
      features.module.5.weight: [0.0008, Channels]
      features.module.7.weight: [0.0008, Channels]
      features.module.10.weight: [0.0008, Channels]
      features.module.12.weight: [0.0008, Channels]
      features.module.14.weight: [0.0008, Channels]
      features.module.17.weight: [0.0008, Channels]
      features.module.19.weight: [0.0008, Channels]
      features.module.21.weight: [0.0008, Channels]
      features.module.24.weight: [0.0008, Channels]
      features.module.26.weight: [0.0008, Channels]
      features.module.28.weight: [0.0008, Channels]

    # NOTE(review): presumably the criterion used to decide which groups get
    # zeroed (mean absolute value of the group) - confirm.
    threshold_criteria: Mean_Abs

# Extension instances. `net_thinner` is the name the `extension` policy
# further down refers to via `instance_name`.
extensions:
  net_thinner:
    class: ChannelRemover
    thinning_func_str: remove_channels
    # Same architecture / dataset as the training command in the file header.
    arch: vgg16_cifar
    dataset: cifar10

# Schedule: binds each named instance defined earlier in this file to the
# epochs in which it is active.
policies:
  # LR scheduling covers the whole run: `ending_epoch` is beyond the 180 epochs
  # used by the training command in the file header.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 300
    frequency: 1

  # Channel removal must happen AFTER regularization has driven whole channels
  # to zero. Scheduling it at epoch 0 would run it on a network with nothing to
  # remove, so the thinned shapes shown in the file header could never be
  # produced.
  # NOTE(review): 179 assumes the 180-epoch run (--epochs=180) from the header
  # command, with epochs counted from 0; adjust if the epoch count changes.
  - extension:
      instance_name: net_thinner
    epochs: [179]

  # Group-lasso channel regularization is applied for the entire training run.
  - regularizer:
      instance_name: Channels_groups_regularizer
      args:
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1
