# Fine-grained (element-wise) pruning of ResNet18 on the ImageNet dataset, using Automated Gradual Pruner (AGP) scheduling.
# 1. The Top1 result (69.872) is better than the TorchVision baseline (69.76) from which we start.
# 2. Note the high rate of 2D sparsity in some 3x3 conv layers; module.layer1.0.conv1.weight even has 8 of its 64 channels removed.
#
# time python3 compress_classifier.py -a=resnet18  -p=50 ../../../data.imagenet/ -j=22   --epochs=100 --lr=0.001 --compress=../agp-pruning/resnet18.schedule_agp.yaml --pretrained
#
# Parameters:
# +----+-------------------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# |    | Name                                | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (64, 3, 7, 7)    |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12031 |  0.00002 |    0.07019 |
# |  1 | module.layer1.0.conv1.weight        | (64, 64, 3, 3)   |         36864 |           8848 |    0.00000 |    0.00000 | 12.50000 | 34.22852 |  0.00000 |   75.99826 | 0.04662 | -0.00260 |    0.01883 |
# |  2 | module.layer1.0.conv2.weight        | (64, 64, 3, 3)   |         36864 |           8848 |    0.00000 |    0.00000 |  0.00000 | 30.02930 |  0.00000 |   75.99826 | 0.03763 | -0.00066 |    0.01677 |
# |  3 | module.layer1.1.conv1.weight        | (64, 64, 3, 3)   |         36864 |          12166 |    0.00000 |    0.00000 |  0.00000 | 14.84375 |  0.00000 |   66.99761 | 0.04443 | -0.00199 |    0.02160 |
# |  4 | module.layer1.1.conv2.weight        | (64, 64, 3, 3)   |         36864 |          12166 |    0.00000 |    0.00000 |  0.00000 | 14.57520 |  0.00000 |   66.99761 | 0.03794 | -0.00105 |    0.01946 |
# |  5 | module.layer2.0.conv1.weight        | (128, 64, 3, 3)  |         73728 |          24331 |    0.00000 |    0.00000 |  0.00000 | 17.57812 |  0.00000 |   66.99897 | 0.03602 | -0.00093 |    0.01835 |
# |  6 | module.layer2.0.conv2.weight        | (128, 128, 3, 3) |        147456 |          58983 |    0.00000 |    0.00000 |  0.00000 |  4.83398 |  0.00000 |   59.99959 | 0.02995 | -0.00099 |    0.01620 |
# |  7 | module.layer2.0.downsample.0.weight | (128, 64, 1, 1)  |          8192 |           2704 |    0.00000 |    0.00000 |  0.00000 | 66.99219 |  0.78125 |   66.99219 | 0.06225 | -0.00253 |    0.02885 |
# |  8 | module.layer2.1.conv1.weight        | (128, 128, 3, 3) |        147456 |          35390 |    0.00000 |    0.00000 |  0.00000 | 23.21167 |  0.00000 |   75.99962 | 0.02837 | -0.00095 |    0.01250 |
# |  9 | module.layer2.1.conv2.weight        | (128, 128, 3, 3) |        147456 |          35390 |    0.00000 |    0.00000 |  0.00000 | 23.16284 |  0.00000 |   75.99962 | 0.02450 | -0.00079 |    0.01110 |
# | 10 | module.layer3.0.conv1.weight        | (256, 128, 3, 3) |        294912 |          70779 |    0.00000 |    0.00000 |  0.00000 | 31.27441 |  0.00000 |   75.99996 | 0.02386 | -0.00077 |    0.01063 |
# | 11 | module.layer3.0.conv2.weight        | (256, 256, 3, 3) |        589824 |         194642 |    0.00000 |    0.00000 |  0.00000 | 11.57684 |  0.00000 |   66.99999 | 0.02140 | -0.00055 |    0.01106 |
# | 12 | module.layer3.0.downsample.0.weight | (256, 128, 1, 1) |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03017 | -0.00168 |    0.02186 |
# | 13 | module.layer3.1.conv1.weight        | (256, 256, 3, 3) |        589824 |         235930 |    0.00000 |    0.00000 |  0.00000 |  6.06537 |  0.00000 |   59.99993 | 0.01956 | -0.00122 |    0.01110 |
# | 14 | module.layer3.1.conv2.weight        | (256, 256, 3, 3) |        589824 |         235930 |    0.00000 |    0.00000 |  0.00000 |  7.33185 |  0.00000 |   59.99993 | 0.01806 | -0.00111 |    0.01034 |
# | 15 | module.layer4.0.conv1.weight        | (512, 256, 3, 3) |       1179648 |         471860 |    0.00000 |    0.00000 |  0.00000 |  7.11365 |  0.00000 |   59.99993 | 0.01723 | -0.00116 |    0.01002 |
# | 16 | module.layer4.0.conv2.weight        | (512, 512, 3, 3) |       2359296 |         943719 |    0.00000 |    0.00000 |  0.00000 |  3.83797 |  0.00000 |   59.99997 | 0.01501 | -0.00095 |    0.00875 |
# | 17 | module.layer4.0.downsample.0.weight | (512, 256, 1, 1) |        131072 |          52429 |    0.00000 |    0.00000 |  0.00000 | 59.99985 |  0.00000 |   59.99985 | 0.02863 | -0.00054 |    0.01600 |
# | 18 | module.layer4.1.conv1.weight        | (512, 512, 3, 3) |       2359296 |         943719 |    0.00000 |    0.00000 |  0.00000 |  7.81326 |  0.00000 |   59.99997 | 0.01539 | -0.00166 |    0.00915 |
# | 19 | module.layer4.1.conv2.weight        | (512, 512, 3, 3) |       2359296 |         778568 |    0.00000 |    0.00000 |  0.00000 | 36.92665 |  0.00000 |   66.99999 | 0.01105 |  0.00002 |    0.00594 |
# | 20 | module.fc.weight                    | (1000, 512)      |        512000 |         512000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06898 |  0.00001 |    0.05061 |
# | 21 | Total sparsity:                     | -                |      11678912 |        4680578 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   59.92283 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# Total sparsity: 59.92
#
# --- validate (epoch=99)-----------
# 128116 samples (256 per mini-batch)
# Epoch: [99][   50/  500]    Loss 1.377026    Top1 67.554688    Top5 86.484375
# Epoch: [99][  100/  500]    Loss 1.370474    Top1 67.761719    Top5 86.554688
# Epoch: [99][  150/  500]    Loss 1.377697    Top1 67.507812    Top5 86.398438
# Epoch: [99][  200/  500]    Loss 1.380634    Top1 67.414062    Top5 86.316406
# Epoch: [99][  250/  500]    Loss 1.381116    Top1 67.448437    Top5 86.282813
# Epoch: [99][  300/  500]    Loss 1.382881    Top1 67.437500    Top5 86.300781
# Epoch: [99][  350/  500]    Loss 1.380314    Top1 67.500000    Top5 86.368304
# Epoch: [99][  400/  500]    Loss 1.380993    Top1 67.465820    Top5 86.353516
# Epoch: [99][  450/  500]    Loss 1.383427    Top1 67.413194    Top5 86.314236
# Epoch: [99][  500/  500]    Loss 1.386347    Top1 67.363281    Top5 86.282813
# ==> Top1: 67.372    Top5: 86.286    Loss: 1.386
#
# Saving checkpoint
# --- test ---------------------
# 50000 samples (256 per mini-batch)
# Test: [   50/  195]    Loss 0.911991    Top1 76.109375    Top5 93.242188
# Test: [  100/  195]    Loss 1.040435    Top1 73.699219    Top5 91.812500
# Test: [  150/  195]    Loss 1.182236    Top1 70.924479    Top5 89.895833
# ==> Top1: 69.872    Top5: 89.162    Loss: 1.238
#
#
# Log file for this run: /data/home/cvds_lab/nzmora/private-distiller/examples/classifier_compression/logs/2018.04.22-224857/2018.04.22-224857.log
#
# real    989m37.564s
# user    15908m42.302s
# sys     1909m6.207s


version: 1

# Three AutomatedGradualPruner instances, grouped by target sparsity.
# Each instance lists the weight tensors it applies to.
# module.conv1.weight, module.layer3.0.downsample.0.weight and
# module.fc.weight appear in none of the lists below.
pruners:
  # Target: final_sparsity 0.60.
  low_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.60
    weights:
      - module.layer2.0.conv2.weight
      - module.layer3.1.conv1.weight
      - module.layer3.1.conv2.weight
      - module.layer4.0.conv1.weight
      - module.layer4.0.conv2.weight
      - module.layer4.0.downsample.0.weight
      - module.layer4.1.conv1.weight

  # Target: final_sparsity 0.67.
  mid_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.67
    weights:
      - module.layer1.1.conv1.weight
      - module.layer1.1.conv2.weight
      - module.layer2.0.conv1.weight
      - module.layer2.0.downsample.0.weight
      - module.layer3.0.conv2.weight
      - module.layer4.1.conv2.weight

  # Target: final_sparsity 0.76.
  high_pruner:
    class: AutomatedGradualPruner
    initial_sparsity: 0.05
    final_sparsity: 0.76
    weights:
      - module.layer1.0.conv1.weight
      - module.layer1.0.conv2.weight
      - module.layer2.1.conv1.weight
      - module.layer2.1.conv2.weight
      - module.layer3.0.conv1.weight

# Learning-rate scheduler; the `policies` section refers to it by the
# instance name `pruning_lr`.
# NOTE(review): presumably maps to torch.optim.lr_scheduler.ExponentialLR,
# which multiplies the learning rate by `gamma` on each step - confirm.
lr_schedulers:
  pruning_lr:
    class: ExponentialLR
    gamma: 0.9


# Schedule: binds each pruner / LR scheduler instance to an epoch window.
policies:
  # low_pruner is the only pruner that starts at epoch 0.
  - pruner:
      instance_name: low_pruner
    starting_epoch: 0
    ending_epoch: 16
    frequency: 2

  # mid_pruner and high_pruner share the same window: epochs 4 to 16.
  - pruner:
      instance_name: mid_pruner
    starting_epoch: 4
    ending_epoch: 16
    frequency: 2

  - pruner:
      instance_name: high_pruner
    starting_epoch: 4
    ending_epoch: 16
    frequency: 2

  # The LR scheduler starts at epoch 13, before the pruning windows end at 16.
  - lr_scheduler:
      instance_name: pruning_lr
    starting_epoch: 13
    ending_epoch: 100
    frequency: 1
