# Scheduler for training / re-training a model using quantization-aware training, with a linear, range-based quantizer
#
# The setting here is 8-bit weights and activations. For vision models, this is usually applied to the entire model,
# without exceptions. Hence, this scheduler isn't model-specific as-is. It doesn't define any name-based overrides.
#
# At the moment this quantizer will:
#  * Quantize weights and biases for all convolution and FC layers
#  * Quantize all ReLU activations
#
# Here's an example run for fine-tuning the ResNet-18 model from torchvision:
#
# python compress_classifier.py -a resnet18 -p 50 -b 256 ~/datasets/imagenet --epochs 10 --compress=../quantization/quant_aware_train/quant_aware_train_linear_quant.yaml --pretrained -j 22 --lr 0.0001 --vs 0 --gpu 0
#
# After 6 epochs we get:
#
# 2018-11-22 20:41:03,662 - --- validate (epoch=6)-----------
# 2018-11-22 20:41:03,663 - 50000 samples (256 per mini-batch)
# 2018-11-22 20:41:23,507 - Epoch: [6][   50/  195]    Loss 0.896985    Top1 76.320312    Top5 93.460938
# 2018-11-22 20:41:33,633 - Epoch: [6][  100/  195]    Loss 1.026040    Top1 74.007812    Top5 91.984375
# 2018-11-22 20:41:44,142 - Epoch: [6][  150/  195]    Loss 1.168643    Top1 71.197917    Top5 90.041667
# 2018-11-22 20:41:51,505 - ==> Top1: 70.188    Top5: 89.376    Loss: 1.223
#
# This is an improvement compared to the pre-trained torchvision model:
# 2018-11-07 15:45:53,435 - ==> Top1: 69.758    Top5: 89.078    Loss: 1.251
#
# (Note that the command line above is not using --deterministic, so results could vary a little bit)

# Quantizer definitions. Each entry is keyed by the instance name that a policy references via `instance_name`.
quantizers:
  linear_quantizer:
    class: QuantAwareTrainRangeLinearQuantizer
    # 8-bit activations and weights, matching the setting described in the file header.
    bits_activations: 8
    bits_weights: 8
    mode: 'ASYMMETRIC_UNSIGNED'  # Can try "SYMMETRIC" as well
    ema_decay: 0.999  # Decay value for exponential moving average tracking of activation ranges
    # Canonical lowercase boolean spelling (avoids the YAML 1.1 truthy-variant ambiguity).
    per_channel_wts: true

policies:
  - quantizer:
      instance_name: linear_quantizer
    # Deliberately wide epoch window, so the policy covers either training from scratch or resuming from a
    # pre-trained checkpoint saved at some unknown epoch.
    starting_epoch: 0
    ending_epoch: 300
    frequency: 1
