# SSL: Channel regularization
# We found the regularization to be "unstable": when we raise the regularization strength on one parameter's Channels,
# other parameters become denser.  It is very hard to direct the level and location of sparsity.
# We also found channel regularization to be hard to accomplish in general, and we think this is due to the large size
# of the channel structure.
#
# We save the output (i.e. checkpoint.pth.tar) in ../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training.yaml -j=1 --deterministic
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.43917 | -0.00063 |    0.30806 |
# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15973 | -0.00975 |    0.08835 |
# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |           1584 |    0.00000 |    0.00000 | 31.25000 | 31.25000 |  0.00000 |   31.25000 | 0.04412 |  0.00146 |    0.02351 |
# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12326 | -0.01273 |    0.06584 |
# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |           1152 |    0.00000 |    0.00000 | 50.00000 | 50.00000 |  0.00000 |   50.00000 | 0.02595 |  0.00153 |    0.01196 |
# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13277 | -0.00825 |    0.07012 |
# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           1296 |    0.00000 |    0.00000 | 43.75000 | 43.75000 |  0.00000 |   43.75000 | 0.02046 | -0.00019 |    0.00930 |
# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12530 | -0.00930 |    0.07898 |
# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           6336 |    0.00000 |    0.00000 | 31.25000 | 31.25000 |  0.00000 |   31.25000 | 0.02347 | -0.00049 |    0.01420 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.34174 | -0.03220 |    0.24618 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07267 | -0.00824 |    0.04494 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           5760 |    0.00000 |    0.00000 | 37.50000 | 37.50000 |  0.00000 |   37.50000 | 0.00942 | -0.00002 |    0.00535 |
# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13728 | -0.01268 |    0.10712 |
# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12001 | -0.00505 |    0.09541 |
# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09361 | -0.01169 |    0.06698 |
# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          29952 |    0.00000 |    0.00000 | 18.75000 | 18.75000 |  0.00000 |   18.75000 | 0.02503 | -0.00022 |    0.01723 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.19990 | -0.02006 |    0.15585 |
# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08028 | -0.00894 |    0.05929 |
# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          32256 |    0.00000 |    0.00000 | 12.50000 | 12.50000 |  0.00000 |   12.50000 | 0.01544 | -0.00052 |    0.01097 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09532 | -0.00895 |    0.07573 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06537 | -0.00463 |    0.05136 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.63531 | -0.00003 |    0.52549 |
# | 22 | Total sparsity:                     | -              |        270896 |         250160 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    7.65460 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-06-10 17:20:18,750 - Total sparsity: 7.65
#
# 2018-06-10 17:20:18,751 - --- validate (epoch=179)-----------
# 2018-06-10 17:20:18,751 - 5000 samples (256 per mini-batch)
# 2018-06-10 17:20:20,342 - ==> Top1: 90.700    Top5: 99.680    Loss: 0.337
#
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.140    Top5: 99.690    Loss: 0.346
#
#
# real    36m12.734s
# user    78m32.813s
# sys     10m38.612s



# Learning-rate schedule referenced by the policies section as "training_lr".
# NOTE(review): presumably maps to PyTorch's StepLR (LR multiplied by `gamma`
# every `step_size` epochs) - confirm against the loader.
lr_schedulers:
  training_lr:
    class: StepLR
    gamma: 0.1
    step_size: 45

# Group-Lasso regularizer applied with the "Channels" group structure to the
# second convolution (conv2) of the residual blocks listed below.
# Each reg_regims entry has the form:
#   <parameter name>: [<regularization strength>, <group structure>]
#
# Every key appears exactly once: duplicate mapping keys are invalid YAML and
# are either rejected or silently collapsed (last one wins) by the loader.
regularizers:
  Channels_groups_regularizer:
    class: GroupLassoRegularizer
    reg_regims:
      #module.layer1.0.conv1.weight: [0.0008, Channels]
      module.layer1.0.conv2.weight: [0.0008, Channels]
      #module.layer1.1.conv1.weight: [0.0008, Channels]
      module.layer1.1.conv2.weight: [0.0008, Channels]
      #module.layer1.2.conv1.weight: [0.0008, Channels]
      module.layer1.2.conv2.weight: [0.0008, Channels]
      #module.layer2.0.conv1.weight: [0.0008, Channels]
      module.layer2.0.conv2.weight: [0.0008, Channels]
      #module.layer2.1.conv1.weight: [0.0008, Channels]
      module.layer2.1.conv2.weight: [0.0008, Channels]
      # layer2.2 is left un-regularized: the results table at the top of this
      # file reports 0% channel sparsity for module.layer2.2.conv2.weight.
      # Uncomment the conv2 line to regularize it as well (this changes the
      # recorded results).
      #module.layer2.2.conv1.weight: [0.0008, Channels]
      #module.layer2.2.conv2.weight: [0.0008, Channels]
      #module.layer3.0.conv1.weight: [0.0008, Channels]
      module.layer3.0.conv2.weight: [0.0008, Channels]
      #module.layer3.1.conv1.weight: [0.0008, Channels]
      module.layer3.1.conv2.weight: [0.0008, Channels]
      # layer3.2 is left un-regularized for the same reason (0% channel
      # sparsity reported for module.layer3.2.conv2.weight).
      #module.layer3.2.conv1.weight: [0.0008, Channels]
      #module.layer3.2.conv2.weight: [0.0008, Channels]
    threshold_criteria: Mean_Abs

# Extension referenced by the policies section as "net_thinner"; it is
# configured to call the `remove_channels` thinning function for the
# resnet20_cifar / cifar10 combination.
extensions:
  net_thinner:
    class: ChannelRemover
    thinning_func_str: remove_channels
    arch: resnet20_cifar
    dataset: cifar10

# Binds the instances declared in the sections above to training epochs.
policies:
  # Learning-rate schedule, active from epoch 45 onwards.
  # NOTE(review): ending_epoch (300) is beyond the 180 epochs used in the
  # command line recorded in the header - confirm this is intentional.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 45
    ending_epoch: 300
    frequency: 1

  # After completing the regularization, we perform network thinning and exit.
  - extension:
      instance_name: net_thinner
    epochs:
      - 179

  # Channel group regularization, applied every epoch in the range 0..180.
  - regularizer:
      instance_name: Channels_groups_regularizer
      args:
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1

