#
# THIS IS CURRENTLY BROKEN!
# There are two challenges to getting this to work:
#   1. If we want to remove layers, we should use a thinning recipe and leave instructions.  Otherwise, it's very hard to
#      decide whether we should remove layers.
#   2. We need a generic solution for removing layers, perhaps a new module type that wraps a layer and has a bypass gate.
#
# Until we implement a proper solution, if you still want to use this schedule, invoke
# thinning.resnet_cifar_remove_layers from compress_classifier.py.
#
# We used this schedule to train CIFAR10-ResNet20 from scratch with SSL.
# After running this schedule, use ssl_4D-removal_finetuning.yaml to "physically" remove the layers and fine-tune the smaller model.
#
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../ssl/ssl_4D-removal_training.yaml -j=1 --deterministic
#
# Parameters:
#
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+-----------+-----------+-----------+------------+---------+----------+------------+
# |    | Name                                | Shape          |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |    Ch (%) |    2D (%) |    3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
# |----+-------------------------------------+----------------+---------------+----------------+------------+------------+-----------+-----------+-----------+------------+---------+----------+------------|
# |  0 | module.conv1.weight                 | (16, 3, 3, 3)  |           432 |            432 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.52566 | -0.00661 |    0.37310 |
# |  1 | module.layer1.0.conv1.weight        | (16, 16, 3, 3) |          2304 |              0 |    0.00000 |    0.00000 | 100.00000 | 100.00000 | 100.00000 |  100.00000 | 0.00000 |  0.00000 |    0.00000 |
# |  2 | module.layer1.0.conv2.weight        | (16, 16, 3, 3) |          2304 |              0 |    0.00000 |    0.00000 | 100.00000 | 100.00000 | 100.00000 |  100.00000 | 0.00000 |  0.00000 |    0.00000 |
# |  3 | module.layer1.1.conv1.weight        | (16, 16, 3, 3) |          2304 |              0 |    0.00000 |    0.00000 | 100.00000 | 100.00000 | 100.00000 |  100.00000 | 0.00000 |  0.00000 |    0.00000 |
# |  4 | module.layer1.1.conv2.weight        | (16, 16, 3, 3) |          2304 |              0 |    0.00000 |    0.00000 | 100.00000 | 100.00000 | 100.00000 |  100.00000 | 0.00000 |  0.00000 |    0.00000 |
# |  5 | module.layer1.2.conv1.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.16538 | -0.01605 |    0.10076 |
# |  6 | module.layer1.2.conv2.weight        | (16, 16, 3, 3) |          2304 |           2304 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.01700 |  0.00062 |    0.00961 |
# |  7 | module.layer2.0.conv1.weight        | (32, 16, 3, 3) |          4608 |           4608 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.17273 | -0.01612 |    0.12345 |
# |  8 | module.layer2.0.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.14296 | -0.01012 |    0.10956 |
# |  9 | module.layer2.0.downsample.0.weight | (32, 16, 1, 1) |           512 |            512 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.25792 | -0.02183 |    0.16897 |
# | 10 | module.layer2.1.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.09893 | -0.01006 |    0.07186 |
# | 11 | module.layer2.1.conv2.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.01403 |  0.00033 |    0.00989 |
# | 12 | module.layer2.2.conv1.weight        | (32, 32, 3, 3) |          9216 |           9216 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.00211 | -0.00020 |    0.00113 |
# | 13 | module.layer2.2.conv2.weight        | (32, 32, 3, 3) |          9216 |              0 |    0.00000 |    0.00000 | 100.00000 | 100.00000 | 100.00000 |  100.00000 | 0.00000 |  0.00000 |    0.00000 |
# | 14 | module.layer3.0.conv1.weight        | (64, 32, 3, 3) |         18432 |          18432 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.11741 | -0.01275 |    0.09240 |
# | 15 | module.layer3.0.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.10811 | -0.00625 |    0.08531 |
# | 16 | module.layer3.0.downsample.0.weight | (64, 32, 1, 1) |          2048 |           2048 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.14911 | -0.01850 |    0.11430 |
# | 17 | module.layer3.1.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.09496 | -0.01072 |    0.07573 |
# | 18 | module.layer3.1.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.02961 | -0.00108 |    0.02315 |
# | 19 | module.layer3.2.conv1.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.09024 | -0.01083 |    0.07156 |
# | 20 | module.layer3.2.conv2.weight        | (64, 64, 3, 3) |         36864 |          36864 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.01302 |  0.00067 |    0.01009 |
# | 21 | module.fc.weight                    | (10, 64)       |           640 |            640 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    0.00000 | 0.58337 | -0.00001 |    0.48239 |
# | 22 | Total sparsity:                     | -              |        270896 |         252464 |    0.00000 |    0.00000 |   0.00000 |   0.00000 |   0.00000 |    6.80409 | 0.00000 |  0.00000 |    0.00000 |
# +----+-------------------------------------+----------------+---------------+----------------+------------+------------+-----------+-----------+-----------+------------+---------+----------+------------+
# Total sparsity: 6.80
#
# --- validate (epoch=179)-----------
# 5000 samples (256 per mini-batch)
# ==> Top1: 89.660    Top5: 99.680    Loss: 0.404
#
# Saving checkpoint
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 90.720    Top5: 99.640    Loss: 0.401
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/private-distiller/examples/classifier_compression/logs/2018.04.07-010542/2018.04.07-010542.log
#
# real    36m52.051s
# user    75m35.413s
# sys     10m48.064s

# Learning-rate schedulers, keyed by instance name.  The `policies` section
# selects one of them through `instance_name`.
lr_schedulers:
  training_lr:
    # Step decay: the learning rate is scaled by `gamma` once every
    # `step_size` epochs.
    # NOTE(review): assumes `StepLR` follows PyTorch's StepLR semantics --
    # confirm against the schedule loader.
    class: StepLR
    step_size: 45
    gamma: 0.1

# Regularizers, keyed by instance name.  The `policies` section selects one of
# them through `instance_name`.
regularizers:
  4D_groups_regularizer:
    class: GroupLassoRegularizer
    # Maps a parameter name to a two-element list.
    # NOTE(review): presumably [regularization strength, group type], where
    # `4D` treats the whole 4-D weight tensor as a single group so that an
    # entire layer can be driven to zero -- confirm against the regularizer
    # implementation.  The key is spelled `reg_regims` in this file; do not
    # "correct" the spelling without checking what the consumer expects.
    #
    # Commented-out entries are not regularized.  Of the entries below, every
    # conv2 weight is enabled except those of layer2.0 and layer3.0, and conv1
    # weights are enabled only for layer1.0 and layer1.1.
    reg_regims:
      module.layer1.0.conv1.weight: [0.004, 4D]
      module.layer1.0.conv2.weight: [0.004, 4D]
      module.layer1.1.conv1.weight: [0.004, 4D]
      module.layer1.1.conv2.weight: [0.004, 4D]
      #module.layer1.2.conv1.weight: [0.004, 4D]
      module.layer1.2.conv2.weight: [0.004, 4D]
      #module.layer2.0.conv1.weight: [0.004, 4D]
      ##module.layer2.0.conv2.weight: [0.004, 4D]
      #module.layer2.1.conv1.weight: [0.004, 4D]
      module.layer2.1.conv2.weight: [0.004, 4D]
      #module.layer2.2.conv1.weight: [0.004, 4D]
      module.layer2.2.conv2.weight: [0.004, 4D]
      #module.layer3.0.conv1.weight: [0.004, 4D]
      #module.layer3.0.conv2.weight: [0.004, 4D]
      #module.layer3.1.conv1.weight: [0.004, 4D]
      module.layer3.1.conv2.weight: [0.004, 4D]
      #module.layer3.2.conv1.weight: [0.004, 4D]
      module.layer3.2.conv2.weight: [0.004, 4D]
    # NOTE(review): presumably the criterion used to decide when a group is
    # small enough to be zeroed -- confirm against the regularizer
    # implementation.
    threshold_criteria: Mean_Abs

# Policies attach a scheduler or regularizer instance (declared in the
# sections above) to an epoch window.  Each entry names the instance and gives
# `starting_epoch`, `ending_epoch` and `frequency` (in epochs).
policies:
  # Learning-rate decay becomes active at epoch 45.  `ending_epoch` (300) is
  # larger than the 180 epochs used by the command line in the file header, so
  # presumably the scheduler stays active until the end of that run.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 45
    ending_epoch: 300
    frequency: 1

  # The group-lasso regularizer is invoked every epoch, from epoch 0 up to
  # epoch 180 (the run length used by the command line in the file header).
  - regularizer:
      instance_name: 4D_groups_regularizer
      args:
        # Canonical lowercase boolean; loads to the same value as `True`.
        # NOTE(review): presumably tells the policy to keep the mask of
        # zeroed groups once thresholding is done -- confirm against the
        # policy implementation.
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1
