# Scheduler for training AlexNet with Batch-Norm, quantized using the DoReFa scheme
# Activations: 8-bits, Weights: 4-bits
# See:
#  https://nervanasystems.github.io/distiller/algo_quantization/index.html#dorefa
#  https://arxiv.org/abs/1606.06160
#
# IMPORTANT NOTES:
# ----------------
#  1. Note that this is not the original AlexNet, but AlexNet with batch normalization layers.
#     See model implementation in <distiller-root>/distiller/models/imagenet/alexnet_batchnorm.py
#  2. The best results are achieved with the Adam optimizer. However, the image classification sample
#     hard-codes the SGD optimizer and offers no option to change it, so the code must be edited:
#     * Open <distiller-root>/distiller/apputils/image_classifier.py
#     * In the function "_init_learner(args)", find the following snippet:
#         optimizer = torch.optim.SGD(model.parameters(),
#             lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
#       And replace it with:
#         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
#
# Command line for training (running from the compress_classifier.py directory):
# python compress_classifier.py --arch alexnet_bn <path_to_imagenet_dataset> -p=50 --epochs=110 --compress=../quantization/quant_aware_train/alexnet_bn_dorefa.yaml -j 22 --lr 0.0002 --wd 0.0001 --vs 0
#
# After 110 epochs we get:
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 1.645975    Top1 61.453125    Top5 83.437500
# Test: [  100/  195]    Loss 1.635810    Top1 61.507812    Top5 83.445312
# Test: [  150/  195]    Loss 1.633975    Top1 61.479167    Top5 83.419271
# ==> Top1: 61.498    Top5: 83.454    Loss: 1.635
#
# So that's a Top1 accuracy of 61.498%, compared to 61.914% for the FP32 model

# Quantizer instances. Each entry is named so that the "policies" section can
# refer to it through instance_name.
quantizers:
  dorefa_quantizer:
    class: DorefaQuantizer
    # Default bit-widths, used for every layer not listed under "overrides".
    bits_activations: 8
    bits_weights: 4
    overrides:
      # Don't quantize the first and last layers: a null bit-width leaves the
      # corresponding weights/activations unquantized.
      # NOTE(review): features.1 and classifier.5 are presumably the modules
      # adjacent to the first and last layers -- confirm against
      # <distiller-root>/distiller/models/imagenet/alexnet_batchnorm.py
      features.0: {bits_weights: null, bits_activations: null}
      features.1: {bits_weights: null, bits_activations: null}
      classifier.5: {bits_weights: null, bits_activations: null}
      classifier.6: {bits_weights: null, bits_activations: null}

# Learning-rate scheduler instances, referenced from the "policies" section
# through instance_name.
lr_schedulers:
  training_lr:
    class: MultiStepLR
    # Epochs at which the learning rate is multiplied by gamma.
    milestones:
      - 60
      - 75
    # Decay factor: the learning rate becomes 0.2x its previous value at each
    # milestone.
    gamma: 0.2

# Policies bind the named quantizer and LR-scheduler instances to an epoch
# range. ending_epoch (200) exceeds the 110 epochs used in the training command
# given in the header of this file, so both policies cover the entire run.
policies:
  # Quantization-aware training with the DoReFa quantizer, from the first epoch.
  - quantizer:
      instance_name: dorefa_quantizer
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1

  # Learning-rate schedule; frequency 1 presumably means the scheduler is
  # stepped every epoch -- confirm against the scheduler implementation.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 200
    frequency: 1
