# SSL: Channel regularization
# In this experiment we increase the regularization strength of some of the channel regularization terms.
# We want to increase the compute compression, while allowing some reduction in the accuracy performance.
# We get 55.3% compute density (x1.81 less compute) with Test Top1: 90.59 after training, and 91.02
# after fine-tuning.
# Our baseline benchmark is Test Top1: 91.78 (@ Total MACs: 40,813,184)
#
# Total MACs: 22,583,936 == 55.3% compute density
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training_x1.8.yaml -j=1 --deterministic
#
# To fine-tune:
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning.yaml -j=1 --deterministic --resume=...
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.47378 | -0.00232 |    0.34033 |
# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16038 | -0.01380 |    0.07922 |
# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           1152 |    0.00000 |    0.00000 | 50.00000 | 50.00000 |  0.00000 |   50.00000 | 0.02244 |  0.00172 |    0.01011 |
# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10908 | -0.00650 |    0.04960 |
# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |            864 |    0.00000 |    0.00000 | 62.50000 | 62.50000 |  0.00000 |   62.50000 | 0.00781 | -0.00051 |    0.00293 |
# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14862 | -0.00256 |    0.06741 |
# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           1008 |    0.00000 |    0.00000 | 56.25000 | 56.25000 |  0.00000 |   56.25000 | 0.01798 | -0.00027 |    0.00779 |
# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14366 | -0.00874 |    0.09717 |
# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           7488 |    0.00000 |    0.00000 | 18.75000 | 18.75000 |  0.00000 |   18.75000 | 0.03211 | -0.00116 |    0.02158 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.30729 | -0.03178 |    0.20665 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08802 | -0.00641 |    0.04504 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           4032 |    0.00000 |    0.00000 | 56.25000 | 56.25000 |  0.00000 |   56.25000 | 0.00761 | -0.00009 |    0.00363 |
# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05628 | -0.00200 |    0.01968 |
# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           1728 |    0.00000 |    0.00000 | 81.25000 | 81.25000 |  0.00000 |   81.25000 | 0.00369 | -0.00006 |    0.00104 |
# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10371 | -0.01132 |    0.07409 |
# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          30528 |    0.00000 |    0.00000 | 17.18750 | 17.18750 |  0.00000 |   17.18750 | 0.03098 | -0.00022 |    0.02149 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18813 | -0.02071 |    0.14541 |
# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05896 | -0.00535 |    0.03235 |
# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          17280 |    0.00000 |    0.00000 | 53.12500 | 53.12500 |  0.00000 |   53.12500 | 0.00718 |  0.00004 |    0.00369 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10843 | -0.01200 |    0.08641 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07188 | -0.00414 |    0.05606 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.67091 | -0.00003 |    0.54928 |
# | 22 | Total sparsity:                     | -              |        270896 |         226688 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   16.31918 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-09-20 16:34:17,976 - Total sparsity: 16.32
#
# 2018-09-20 16:34:17,976 - --- validate (epoch=179)-----------
# 2018-09-20 16:34:17,976 - 5000 samples (256 per mini-batch)
# 2018-09-20 16:34:19,474 - ==> Top1: 90.240    Top5: 99.820    Loss: 0.317
#
# 2018-09-20 16:34:19,475 - ==> using cifar10 dataset
# 2018-09-20 16:34:19,475 - => creating resnet20_cifar model for CIFAR10
# 2018-09-20 16:34:19,586 - Invoking create_thinning_recipe_channels
# 2018-09-20 16:34:19,588 - In tensor module.layer1.0.conv2.weight found 8/16 zero channels
# 2018-09-20 16:34:19,590 - In tensor module.layer1.1.conv2.weight found 10/16 zero channels
# 2018-09-20 16:34:19,592 - In tensor module.layer1.2.conv2.weight found 9/16 zero channels
# 2018-09-20 16:34:19,594 - In tensor module.layer2.0.conv2.weight found 6/32 zero channels
# 2018-09-20 16:34:19,597 - In tensor module.layer2.1.conv2.weight found 18/32 zero channels
# 2018-09-20 16:34:19,599 - In tensor module.layer2.2.conv2.weight found 26/32 zero channels
# 2018-09-20 16:34:19,601 - In tensor module.layer3.0.conv2.weight found 11/64 zero channels
# 2018-09-20 16:34:19,604 - In tensor module.layer3.1.conv2.weight found 34/64 zero channels
# 2018-09-20 16:34:19,615 - Created, applied and saved a thinning recipe
# 2018-09-20 16:34:19,616 - ==> Best Top1: 90.820   On Epoch: 144
#
# 2018-09-20 16:34:19,616 - Saving checkpoint to: logs/removal_training_x1.8___2018.09.20-160134/removal_training_x1.8_checkpoint.pth.tar
# 2018-09-20 16:34:19,628 - --- test ---------------------
# 2018-09-20 16:34:19,629 - 10000 samples (256 per mini-batch)
# 2018-09-20 16:34:21,490 - ==> Top1: 90.590    Top5: 99.660    Loss: 0.333



# Learning-rate decay used during regularized training: a StepLR schedule
# that scales the learning rate by `gamma` once every `step_size` epochs.
lr_schedulers:
  training_lr:
    class: StepLR
    step_size: 45  # epochs between successive decays
    gamma: 0.1  # multiplicative decay factor applied at each step

# Group-lasso regularization at channel granularity. Only the second
# convolution (conv2) of each residual block is listed; each entry is
# [regularization strength, group type].
regularizers:
  Channels_groups_regularizer:
    class: GroupLassoRegularizer
    # NOTE(review): the `reg_regims` spelling presumably matches the key the
    # consumer reads -- do not rename without checking.
    reg_regims:
      module.layer1.0.conv2.weight: [0.0014, Channels]
      module.layer1.1.conv2.weight: [0.0014, Channels]
      module.layer1.2.conv2.weight: [0.0012, Channels]
      module.layer2.0.conv2.weight: [0.0008, Channels]   # sensitive: weaker regularization
      module.layer2.1.conv2.weight: [0.0014, Channels]
      module.layer2.2.conv2.weight: [0.0014, Channels]
      module.layer3.0.conv2.weight: [0.0008, Channels]  # sensitive: weaker regularization
      module.layer3.1.conv2.weight: [0.0014, Channels]
      # Deliberately disabled: the last block's conv2 is left unregularized.
      #module.layer3.2.conv2.weight: [0.0006, Channels] # very sensitive
    # NOTE(review): presumably the statistic used to decide which channel
    # groups get zeroed -- confirm against GroupLassoRegularizer.
    threshold_criteria: Mean_Abs

# Thinning extension: removes the channels that regularization drove to zero
# (see the "found N/M zero channels" lines in the log captured above).
extensions:
  net_thinner:
    class: ChannelRemover
    thinning_func_str: remove_channels
    # Model and dataset names; they match --arch and the data set used in the
    # training command line in the header.
    arch: resnet20_cifar
    dataset: cifar10

policies:
  # Step the learning-rate schedule on every epoch from epoch 45 onwards.
  # ending_epoch (300) lies beyond the 180-epoch run shown in the header, so
  # the scheduler stays active until training ends.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 45
    ending_epoch: 300
    frequency: 1

  # After completing the regularization, we perform network thinning and exit.
  # Epoch 179 is the final epoch of the 180-epoch training run.
  - extension:
      instance_name: net_thinner
    epochs: [179]

  # Channel group-lasso regularization, applied on every epoch of the run.
  - regularizer:
      instance_name: Channels_groups_regularizer
      args:
        # NOTE(review): presumably keeps the mask of zeroed channels in force
        # once they are pruned -- confirm against GroupLassoRegularizer.
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1
