# We used this schedule to train CIFAR10-Plain20 from scratch.
#
# Plain-20 is defined in "Deep Residual Learning for Image Recognition".
# The configuration is based on section 4.2 of "Deep Residual Learning for Image Recognition":
#   "We use a weight decay of 0.0001 and momentum of 0.9, and adopt the weight initialization in [13] and BN [16] but
#   with no dropout. These models are trained with a mini batch size of 128 on two GPUs. We start with a learning
#   rate of 0.1, divide it by 10 at 32k and 48k iterations, and terminate training at 64k iterations, which is
#   determined on a 45k/5k train/val split. We follow the simple data augmentation in [24] for training: 4 pixels are
#   padded on each side, and a 32x32 crop is randomly sampled from the padded image or its horizontal flip. For testing,
#   we only evaluate the single view of the original 32x32 image."
#
# We translate "iterations" to "epochs" because Distiller schedules at the epoch granularity:
#   45K training samples / batch 128 == 351.6 iterations per epoch
#   32K iterations = 91 epochs
#   48K iterations = 137 epochs
#   64K iterations = 182 epochs
#
# Our target test Top1 is 90.5.  This is inferred from Figure 6 of "Deep Residual Learning for Image Recognition", and
# also the accuracy achieved in AMC, Table 2.
#
# References:
#   Yihui He, Ji Lin, Zhijian Liu, Hanrui Wang, Li-Jia Li, and Song Han.
#   AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
#   arXiv:1802.03494v3
#
#   Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun.
#   Deep Residual Learning for Image Recognition.
#   arXiv:1512.03385
#
#
# time python3  compress_classifier.py --arch=plain20_cifar ../../../data.cifar -p=50 --lr=0.1 --epochs=180 --batch=128 --compress=../baseline_networks/cifar/plain20_cifar_baseline_training.yaml --gpu=0 -j=1 --deterministic
#
# Results:
#   Top1 = 90.18 - which is 0.3% lower than our goal.
#   *For better results, with much shorter training, see the explanation after the tables below.
#
# Parameters:
# +----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                         | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight          | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.40074 | -0.00071 |    0.29948 |
# |  1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17085 | -0.01192 |    0.12854 |
# |  2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17880 | -0.01883 |    0.13891 |
# |  3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18079 | -0.00512 |    0.13792 |
# |  4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17919 | -0.00807 |    0.13943 |
# |  5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18720 | -0.01524 |    0.14466 |
# |  6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18216 | -0.00676 |    0.14077 |
# |  7 | module.layer2.0.conv1.weight | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14269 | -0.00945 |    0.10973 |
# |  8 | module.layer2.0.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13417 | -0.00725 |    0.10532 |
# |  9 | module.layer2.1.conv1.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13663 | -0.00779 |    0.10872 |
# | 10 | module.layer2.1.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13405 | -0.00875 |    0.10667 |
# | 11 | module.layer2.2.conv1.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12706 | -0.01244 |    0.10117 |
# | 12 | module.layer2.2.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12098 | -0.00570 |    0.09625 |
# | 13 | module.layer3.0.conv1.weight | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09828 | -0.00750 |    0.07821 |
# | 14 | module.layer3.0.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09798 | -0.00763 |    0.07826 |
# | 15 | module.layer3.1.conv1.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.10033 | -0.00928 |    0.08020 |
# | 16 | module.layer3.1.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08956 | -0.01220 |    0.07165 |
# | 17 | module.layer3.2.conv1.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07004 | -0.01346 |    0.05663 |
# | 18 | module.layer3.2.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04815 |  0.00177 |    0.03756 |
# | 19 | module.fc.weight             | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.54818 | -0.00011 |    0.50385 |
# | 20 | Total sparsity:              | -              |        268336 |         268336 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=179)-----------
# 5000 samples (128 per mini-batch)
# Epoch: [179][   10/   39]    Loss 0.391990    Top1 89.062500    Top5 99.609375
# Epoch: [179][   20/   39]    Loss 0.373019    Top1 89.960938    Top5 99.453125
# Epoch: [179][   30/   39]    Loss 0.371198    Top1 90.182292    Top5 99.453125
# Epoch: [179][   40/   39]    Loss 0.360783    Top1 90.100000    Top5 99.440000
# ==> Top1: 90.100    Top5: 99.440    Loss: 0.361
#
# ==> Best Top1: 90.540 on Epoch: 163
# Saving checkpoint to: logs/2018.12.11-134350/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (128 per mini-batch)
# Test: [   10/   78]    Loss 0.410806    Top1 89.609375    Top5 99.531250
# Test: [   20/   78]    Loss 0.438778    Top1 89.218750    Top5 99.296875
# Test: [   30/   78]    Loss 0.419225    Top1 89.791667    Top5 99.427083
# Test: [   40/   78]    Loss 0.421272    Top1 89.921875    Top5 99.472656
# Test: [   50/   78]    Loss 0.409017    Top1 90.046875    Top5 99.562500
# Test: [   60/   78]    Loss 0.401275    Top1 90.169271    Top5 99.583333
# Test: [   70/   78]    Loss 0.400794    Top1 90.111607    Top5 99.609375
# ==> Top1: 90.180    Top5: 99.630    Loss: 0.401
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.12.11-134350/2018.12.11-134350.log
#
# real    47m26.710s
# user    124m30.606s
# sys     21m1.999s
#
#
# We can achieve a better Top1 result, with faster training by doubling the batch-size to 256, and increasing the initial
# learning-rate to 0.3.
#
# time python3 compress_classifier.py --arch=plain20_cifar ../../../data.cifar --lr=0.3 --epochs=180 --batch=256 --compress=../automated_deep_compression/plain20_cifar_baseline_training.yaml -j=1 --deterministic
#
# Results:
#   Top1 = 90.55
#
# Parameters:
# +----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                         | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight          | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.37799 | -0.00179 |    0.27913 |
# |  1 | module.layer1.0.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17171 | -0.01391 |    0.12635 |
# |  2 | module.layer1.0.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17004 | -0.01753 |    0.13081 |
# |  3 | module.layer1.1.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16450 |  0.00003 |    0.12702 |
# |  4 | module.layer1.1.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.16691 | -0.01517 |    0.13133 |
# |  5 | module.layer1.2.conv1.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.17618 | -0.00955 |    0.13691 |
# |  6 | module.layer1.2.conv2.weight | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.18617 | -0.00262 |    0.14352 |
# |  7 | module.layer2.0.conv1.weight | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14799 |  0.00621 |    0.11439 |
# |  8 | module.layer2.0.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13378 | -0.00616 |    0.10422 |
# |  9 | module.layer2.1.conv1.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.13473 | -0.00722 |    0.10616 |
# | 10 | module.layer2.1.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12899 | -0.01044 |    0.10220 |
# | 11 | module.layer2.2.conv1.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12227 | -0.00908 |    0.09684 |
# | 12 | module.layer2.2.conv2.weight | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.11781 | -0.00567 |    0.09389 |
# | 13 | module.layer3.0.conv1.weight | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09735 | -0.00409 |    0.07706 |
# | 14 | module.layer3.0.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09819 | -0.00773 |    0.07806 |
# | 15 | module.layer3.1.conv1.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.09940 | -0.00979 |    0.07937 |
# | 16 | module.layer3.1.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.08860 | -0.01071 |    0.07082 |
# | 17 | module.layer3.2.conv1.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06713 | -0.01385 |    0.05436 |
# | 18 | module.layer3.2.conv2.weight | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04638 |  0.00117 |    0.03604 |
# | 19 | module.fc.weight             | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.53306 | -0.00001 |    0.48409 |
# | 20 | Total sparsity:              | -              |        268336 |         268336 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
# +----+------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 0.00
#
# --- validate (epoch=179)-----------
# 5000 samples (256 per mini-batch)
# Epoch: [179][   10/   19]    Loss 0.369681    Top1 90.000000    Top5 99.687500
# Epoch: [179][   20/   19]    Loss 0.414264    Top1 89.640000    Top5 99.560000
# ==> Top1: 89.640    Top5: 99.560    Loss: 0.414
#
# ==> Best Top1: 90.200 on Epoch: 118
# Saving checkpoint to: logs/2018.12.11-160023/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# Test: [   10/   39]    Loss 0.424408    Top1 90.195312    Top5 99.453125
# Test: [   20/   39]    Loss 0.411984    Top1 90.390625    Top5 99.550781
# Test: [   30/   39]    Loss 0.394514    Top1 90.546875    Top5 99.609375
# Test: [   40/   39]    Loss 0.409653    Top1 90.550000    Top5 99.650000
# ==> Top1: 90.550    Top5: 99.650    Loss: 0.410
#
# real    33m57.110s
# user    69m57.694s
# sys     9m49.416s

# Learning-rate schedulers available to the policies in this file.
#
# `training_lr` is a MultiStepLR scheduler (presumably resolved to
# torch.optim.lr_scheduler.MultiStepLR - confirm against the consumer):
# the learning rate is multiplied by `gamma` each time a milestone epoch
# is reached.  With the --lr=0.1 command line shown in the header that
# would give 0.1 -> 0.02 -> 0.004 -> 0.0008.
#
# NOTE(review): these milestones and gamma are not the paper schedule
# quoted in the header (divide by 10 at 32k and 48k iterations); confirm
# that the difference is intentional before changing either.
lr_schedulers:
  training_lr:
    class: MultiStepLR
    gamma: 0.2
    milestones:
      - 60
      - 120
      - 160

# Scheduling policies: binds the `training_lr` scheduler instance to an
# epoch range.  The range 0..200 covers the 180-epoch runs shown in the
# header; `frequency: 1` presumably means the scheduler is invoked every
# epoch - confirm against the consumer's schedule documentation.
policies:
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1
