# Fine-tuning after channel regularization training (SSL) of ResNet20-CIFAR10.
# In this experiment we increase the regularization strength of some of the channel regularization terms.
# We want to increase the compute compression, while allowing some reduction in the accuracy performance.
# We get 55.3% compute density (1.81x less compute) with a Test Top1 of 90.59 after training and 91.02
# after fine-tuning.
# Our baseline benchmark is Test Top1: 91.78 (@ Total MACs: 40,813,184)
#
# Total MACs: 22,583,936 == 55.3% compute density
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.2 --epochs=98 --compress=../ssl/ssl_channels-removal_finetuning_x1.8.yaml --reset-optimizer --resume-from=../ssl/checkpoints/checkpoint_trained_channel_regularized_resnet20.pth.tar
#
# Parameters:
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.45998 | -0.00238 |    0.32889 |
# |  1 | module.layer1.0.conv1.weight        | (8, 16, 3, 3)  |          1152 |           1152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.22484 | -0.02487 |    0.15544 |
# |  2 | module.layer1.0.conv2.weight        | (16, 8, 3, 3)  |          1152 |           1152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18552 |  0.01769 |    0.12981 |
# |  3 | module.layer1.1.conv1.weight        | (6, 16, 3, 3)  |           864 |            864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17903 | -0.01568 |    0.12988 |
# |  4 | module.layer1.1.conv2.weight        | (16, 6, 3, 3)  |           864 |            864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13164 | -0.01038 |    0.09039 |
# |  5 | module.layer1.2.conv1.weight        | (7, 16, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.22172 | -0.00509 |    0.14915 |
# |  6 | module.layer1.2.conv2.weight        | (16, 7, 3, 3)  |          1008 |           1008 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16140 | -0.00446 |    0.11207 |
# |  7 | module.layer2.0.conv1.weight        | (26, 16, 3, 3) |          3744 |           3744 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16133 | -0.00875 |    0.12191 |
# |  8 | module.layer2.0.conv2.weight        | (32, 26, 3, 3) |          7488 |           7488 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12494 | -0.00761 |    0.09602 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.29094 | -0.02631 |    0.19351 |
# | 10 | module.layer2.1.conv1.weight        | (14, 32, 3, 3) |          4032 |           4032 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13983 | -0.01593 |    0.10753 |
# | 11 | module.layer2.1.conv2.weight        | (32, 14, 3, 3) |          4032 |           4032 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09729 |  0.00072 |    0.07334 |
# | 12 | module.layer2.2.conv1.weight        | (6, 32, 3, 3)  |          1728 |           1728 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13990 | -0.01228 |    0.10862 |
# | 13 | module.layer2.2.conv2.weight        | (32, 6, 3, 3)  |          1728 |           1728 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08845 | -0.00248 |    0.06464 |
# | 14 | module.layer3.0.conv1.weight        | (53, 32, 3, 3) |         15264 |          15264 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11833 | -0.01082 |    0.09355 |
# | 15 | module.layer3.0.conv2.weight        | (64, 53, 3, 3) |         30528 |          30528 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10433 | -0.00275 |    0.08174 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17084 | -0.01491 |    0.13004 |
# | 17 | module.layer3.1.conv1.weight        | (30, 64, 3, 3) |         17280 |          17280 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09878 | -0.00955 |    0.07732 |
# | 18 | module.layer3.1.conv2.weight        | (64, 30, 3, 3) |         17280 |          17280 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07864 | -0.00297 |    0.05934 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10865 | -0.01132 |    0.08681 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06811 | -0.00286 |    0.05318 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.65814 | -0.00001 |    0.54548 |
# | 22 | Total sparsity:                     | -              |        186512 |         186512 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=277)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 90.340    Top5: 99.680    Loss: 0.332
#
# ==> Best Top1: 91.040   On Epoch: 270
#
# Saving checkpoint to: logs/2018.09.20-170103/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.020    Top5: 99.720    Loss: 0.358
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.09.20-170103/2018.09.20-170103.log
#
# real    17m47.293s
# user    40m0.278s
# sys     5m20.199s

# Learning-rate schedulers available to this fine-tuning schedule.
lr_schedulers:
  # Scheduler instance; the `policies` section refers to it by this name.
  training_lr:
    # NOTE(review): presumably resolves to torch.optim.lr_scheduler.StepLR,
    # i.e. the LR is multiplied by `gamma` once every `step_size` epochs - confirm.
    class: StepLR
    step_size: 45
    gamma: 0.1

# Policies applied during the run; this file defines a single LR-scheduler policy.
policies:
  # Bind the `training_lr` instance declared under `lr_schedulers`.
  - lr_scheduler:
      instance_name: training_lr
    # Epoch window and invocation frequency for this policy.
    # The run log in the header validates at epoch 277, which is inside 0..300.
    # NOTE(review): whether `ending_epoch` is inclusive is not visible here - confirm.
    starting_epoch: 0
    ending_epoch: 300
    frequency: 1
