# Fine-grained (element-wise) pruning, using an RNN pruning schedule, for PyTorch's example Word Language model.
# The pruning schedule is based on the following paper from ICLR 2017:
#    Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017).
#    Exploring Sparsity in Recurrent Neural Networks.
#    (https://arxiv.org/abs/1704.05119)
#
# The README of PyTorch's word language model example code promises that this configuration will produce a test
# perplexity of 72.30, but I was only able to get 84.23, so I use 84.23 as the baseline for comparison.
#
# Baseline generation:
# time python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
#
# Pruning:
# python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied --compress=../../examples/baidu-rnn-pruning/word_lang_model.schedule_baidu_rnn.yaml
#
# The Baidu pruner uses a value they refer to as 'q':
#   "In order to determine q in equation 1, we use an existing weight array from a previously trained
#    model. The weights are sorted using absolute values and we pick the weight corresponding to the
#    90th percentile as q."
#
# To determine this 'q' value we first train the baseline network (saved in model.emsize1500.nhid1500.dropout065.tied.pt),
# and then extract the statistics:
# python3 main.py --cuda --resume=model.emsize1500.nhid1500.dropout065.tied.pt --summary=percentile
#
# parameter encoder.weight: q = 0.16
# parameter rnn.weight_ih_l0: q = 0.17
# parameter rnn.weight_hh_l0: q = 0.11
# parameter rnn.weight_ih_l1: q = 0.18
# parameter rnn.weight_hh_l1: q = 0.15
# parameter decoder.weight: q = 0.16
#
# To save you time, you can download a pretrained model from here:
# https://s3-us-west-1.amazonaws.com/nndistiller/agp-pruning/word_language_model/model.emsize1500.nhid1500.dropout065.tied.pt
#

version: 1

# One BaiduRNNPruner instance per pruned parameter tensor. Each 'q' below is the
# value reported for that parameter by the '--summary=percentile' run listed in
# the header of this file. All instances share the same ramp settings.
# NOTE(review): the header also lists a q for decoder.weight, but no pruner
# targets it; presumably this is because the model is trained with --tied,
# so the decoder shares the encoder weights -- confirm.
pruners:
  # Layer 0, input-to-hidden weights.
  ih_l0_rnn_pruner:
    class: BaiduRNNPruner
    q: 0.17
    ramp_epoch_offset: 3
    ramp_slope_mult: 2
    weights:
      - rnn.weight_ih_l0

  # Layer 0, hidden-to-hidden weights.
  hh_l0_rnn_pruner:
    class: BaiduRNNPruner
    q: 0.11
    ramp_epoch_offset: 3
    ramp_slope_mult: 2
    weights:
      - rnn.weight_hh_l0

  # Layer 1, input-to-hidden weights.
  ih_l1_rnn_pruner:
    class: BaiduRNNPruner
    q: 0.18
    ramp_epoch_offset: 3
    ramp_slope_mult: 2
    weights:
      - rnn.weight_ih_l1

  # Layer 1, hidden-to-hidden weights.
  hh_l1_rnn_pruner:
    class: BaiduRNNPruner
    q: 0.15
    ramp_epoch_offset: 3
    ramp_slope_mult: 2
    weights:
      - rnn.weight_hh_l1

  # Embedding (encoder) weights.
  embedding_pruner:
    class: BaiduRNNPruner
    q: 0.16
    ramp_epoch_offset: 3
    ramp_slope_mult: 2
    weights:
      - encoder.weight

# Scheduling policies: each entry binds one pruner instance (by name) to an
# epoch window and a frequency. The windows are staggered by one epoch per
# group: layer 0 uses epochs 4-21, layer 1 uses 5-22, the embedding uses 6-23.
# Every policy uses frequency 3.
policies:
  # Layer 0 pruners: window 4-21.
  - pruner:
      instance_name: ih_l0_rnn_pruner
    starting_epoch: 4
    ending_epoch: 21
    frequency: 3

  - pruner:
      instance_name: hh_l0_rnn_pruner
    starting_epoch: 4
    ending_epoch: 21
    frequency: 3

  # Layer 1 pruners: window 5-22.
  - pruner:
      instance_name: ih_l1_rnn_pruner
    starting_epoch: 5
    ending_epoch: 22
    frequency: 3

  - pruner:
      instance_name: hh_l1_rnn_pruner
    starting_epoch: 5
    ending_epoch: 22
    frequency: 3

  # Embedding pruner: window 6-23.
  - pruner:
      instance_name: embedding_pruner
    starting_epoch: 6
    ending_epoch: 23
    frequency: 3
