# GSS (Guided Structured Sparsity).
# "Attention-Based Guided Structured Sparsity of Deep Neural Networks",
# Amirsina Torfi, Rouzbeh A. Shirvani, Sobhan Soleymani, Nasser M. Nasrabadi
# ICLR 2018
# https://arxiv.org/abs/1802.09902
#
# Add group variance regularization to SSL.
# So far I haven't produced results better than SSL. The regularization strength of the variance does not come into play
# because it seems like the variance cost diminishes very quickly.
#
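# NOTE(review): the --compress argument below points at the SSL schedule; presumably it should point at this file
# for the variance regularizer to be used -- confirm before running.
# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../ssl/ssl_channels-removal_training.yaml -j=1 --deterministic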
#


# Learning-rate scheduler instances. `training_lr` is activated by the
# `lr_scheduler` entry under `policies`.
# NOTE(review): `StepLR` is presumably torch.optim.lr_scheduler.StepLR (LR is
# multiplied by `gamma` every `step_size` scheduler steps) -- confirm.
lr_schedulers:
  training_lr:
    class: StepLR
    step_size: 45
    gamma: 0.1

# Regularizer instances; each one is activated by a `regularizer` entry under
# `policies`. Both instances list the same eight conv2 weight tensors, and
# every reg_regims value is a [number, structure] pair.
# NOTE(review): the number is presumably the regularization strength --
# confirm against the regularizer implementation.
regularizers:
  # Group-lasso regularizer applied with the `Channels` structure.
  Channels_l2_regularizer:
    class: GroupLassoRegularizer
    reg_regims:
      module.layer1.0.conv2.weight: [0.0028, Channels]
      module.layer1.1.conv2.weight: [0.0028, Channels]
      module.layer1.2.conv2.weight: [0.0024, Channels]
      module.layer2.0.conv2.weight: [0.0016, Channels]  # sensitive
      module.layer2.1.conv2.weight: [0.0028, Channels]
      module.layer2.2.conv2.weight: [0.0028, Channels]
      module.layer3.0.conv2.weight: [0.0008, Channels]  # sensitive
      module.layer3.1.conv2.weight: [0.0028, Channels]
      # Left disabled in this schedule:
      # module.layer3.2.conv2.weight: [0.0006, Channels]  # very sensitive
    threshold_criteria: Mean_Abs

  # Group-variance regularizer over the same tensors, with one uniform value.
  # The instance name (spelling included) must match the `instance_name`
  # referenced under `policies`.
  Channels_variance_reguralizer:
    class: GroupVarianceRegularizer
    reg_regims:
      module.layer1.0.conv2.weight: [0.000008, Channels]
      module.layer1.1.conv2.weight: [0.000008, Channels]
      module.layer1.2.conv2.weight: [0.000008, Channels]
      module.layer2.0.conv2.weight: [0.000008, Channels]
      module.layer2.1.conv2.weight: [0.000008, Channels]
      module.layer2.2.conv2.weight: [0.000008, Channels]
      module.layer3.0.conv2.weight: [0.000008, Channels]
      module.layer3.1.conv2.weight: [0.000008, Channels]
      # Left disabled in this schedule:
      # module.layer3.2.conv2.weight: [0.000008, Channels]

# Extension instances. `net_thinner` is triggered by the `extension` entry
# under `policies`.
extensions:
  net_thinner:
    class: ChannelRemover
    thinning_func_str: remove_channels
    arch: resnet20_cifar
    dataset: cifar10

# Policies bind the instances declared above to epochs. The list order is
# kept exactly as written.
policies:
  # Learning-rate schedule, starting at epoch 45. Its ending_epoch (300) lies
  # beyond the 180 epochs used by the example command in the file header.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 45
    ending_epoch: 300
    frequency: 1

  # After completing the regularization, we perform network thinning and exit.
  - extension:
      instance_name: net_thinner
    epochs: [179]

  # Group-lasso channel regularization, epochs 0 to 180.
  - regularizer:
      instance_name: Channels_l2_regularizer
      args:
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1

  # Group-variance channel regularization over the same epoch range.
  - regularizer:
      instance_name: Channels_variance_reguralizer
      args:
        keep_mask: true
    starting_epoch: 0
    ending_epoch: 180
    frequency: 1
