# Scheduler for training pre-activation ResNet on CIFAR-10, quantized using the DoReFa scheme
# See:
#  https://nervanasystems.github.io/distiller/algo_quantization/index.html#dorefa
#  https://arxiv.org/abs/1606.06160
# 
# Applicable to ResNet 20 / 32 / 44 / 56 / 110
# 
# Command line for training (running from the compress_classifier.py directory):
# python compress_classifier.py -a preact_resnet20_cifar --lr 0.1 -p 50 -b 128 <path_to_cifar10_dataset> -j 1 --epochs 200 --compress=../quantization/quant_aware_train/preact_resnet_cifar_dorefa.yaml --wd=0.0002 --vs=0 --gpus 0
#
# Notes:
#  * In '-a preact_resnet20_cifar', replace '20' with the required depth
#  * '--wd=0.0002': Weight decay of 0.0002 is used
#  * '--vs=0': We train on the entire training dataset, and validate using the test set
#
# Knowledge Distillation:
# -----------------------
# To train these models with knowledge distillation, add the following arguments to the command line:
# --kd-teacher preact_resnet44_cifar --kd-resume <path_to_teacher_model_checkpoint> --kd-temp 5.0 --kd-dw 0.7 --kd-sw 0.3
#
# Notes:
#  * Replace 'preact_resnet44_cifar' with the required teacher model
#  * To train baseline FP32 models that can be used as teachers, see preact_resnet_cifar_base_fp32.yaml
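#  * In this example we're using a distillation temperature of 5.0 ('--kd-temp'). We give a weight of 0.7 ('--kd-dw')
#    to the distillation loss (that is - the loss of the student predictions vs. the teacher's soft targets) and a
#    weight of 0.3 ('--kd-sw') to the student loss (the loss of the student predictions vs. the ground-truth labels).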
#  * Note we don't change any of the other training hyper-parameters
#  * More details on knowledge distillation at: 
#    https://nervanasystems.github.io/distiller/schedule/index.html#knowledge-distillation
#
# Experimental results obtained with the hyper-parameters shown above are listed after the YAML schedule.

# Quantizer definitions. The instance name ('dorefa_quantizer') is what the
# 'policies' section refers to via 'instance_name'.
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    # Default bit-widths for every layer that is not listed under 'overrides':
    # 8-bit activations and 3-bit weights (the "w3-a8" configuration).
    bits_activations: 8
    bits_weights: 3
    overrides:
      # Don't quantize first and last layer: a null bit-width disables
      # quantization for the named module.
      # NOTE(review): the keys below are presumably module names of the
      # preact_resnet*_cifar models - confirm against the model definition.
      conv1:
        bits_weights: null
        bits_activations: null
      layer1.0.pre_relu:
        bits_weights: null
        bits_activations: null
      final_relu:
        bits_weights: null
        bits_activations: null
      fc:
        bits_weights: null
        bits_activations: null

# Learning-rate schedule. The instance name ('training_lr') is what the
# 'policies' section refers to via 'instance_name'.
lr_schedulers:
  training_lr:
    class: MultiStepMultiGammaLR
    # Epochs at which the learning rate is changed.
    milestones:
      - 80
      - 120
      - 160
    # One gamma per milestone, in the same order as 'milestones'.
    # NOTE(review): presumably each gamma is the multiplicative factor applied
    # to the learning rate at the matching milestone - confirm against the
    # MultiStepMultiGammaLR documentation.
    gammas:
      - 0.1
      - 0.1
      - 0.2

# Policies bind the instances defined above to a range of epochs.
policies:
  # Apply the DoReFa quantizer for the whole 200-epoch run.
  - quantizer:
      instance_name: dorefa_quantizer
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1

  # Drive the learning-rate schedule. The range ends at 161 so that it still
  # covers the last milestone (epoch 160).
  # NOTE(review): this assumes 'ending_epoch' is exclusive - confirm.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 161
    frequency: 1

# The results listed here are based on 4 runs in each configuration. All results are Top-1 accuracy (%):
#
# +-------+--------------+-------------------------+-------------------------+
# |       |              |           FP32          |       DoReFa w3-a8      |
# +-------+--------------+-------------------------+-------------------------+
# | Depth | Distillation | Best  | Worst | Average | Best  | Worst | Average |
# |       | Teacher      |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | None         | 92.4  | 91.91 | 92.2225 | 91.87 | 91.34 | 91.605  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 32           | 92.85 | 92.68 | 92.7375 | 92.16 | 91.96 | 92.0725 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 44           | 93.09 | 92.64 | 92.795  | 92.54 | 91.9  | 92.2225 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 56           | 92.77 | 92.52 | 92.6475 | 92.53 | 91.92 | 92.15   |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 20    | 110          | 92.87 | 92.66 | 92.7725 | 92.12 | 92.01 | 92.0825 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | None         | 93.31 | 92.93 | 93.13   | 92.66 | 92.33 | 92.485  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 44           | 93.54 | 93.35 | 93.48   | 93.41 | 93.2  | 93.2875 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 56           | 93.58 | 93.47 | 93.5125 | 93.18 | 92.76 | 92.93   |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 32    | 110          | 93.6  | 93.29 | 93.4575 | 93.36 | 92.99 | 93.175  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | None         | 94.07 | 93.5  | 93.7425 | 93.08 | 92.66 | 92.8125 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | 56           | 94.08 | 93.58 | 93.875  | 93.46 | 93.28 | 93.3875 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 44    | 110          | 94.13 | 93.75 | 93.95   | 93.45 | 93.24 | 93.3825 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 56    | None         | 94.2  | 93.52 | 93.8    | 93.44 | 92.91 | 93.0975 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 56    | 110          | 94.47 | 94.0  | 94.16   | 93.83 | 93.56 | 93.7225 |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# |       |              |       |       |         |       |       |         |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
# | 110   | None         | 94.66 | 94.42 | 94.54   | 93.53 | 93.24 | 93.395  |
# +-------+--------------+-------+-------+---------+-------+-------+---------+
