# Tuning Learning Rates with Amortized Proximal Optimization

This repository contains a PyTorch implementation of Amortized Proximal Optimization for learning rate tuning.

## Requirements

* ruamel.yaml
* PyTorch (tested on version 1.5)


## Examples

### MLP on MNIST

**MNIST RMSprop Baseline**
```
python train.py  \
    --save_dir=mnist_baseline  \
    --dataset=mnist  \
    --model=mlp  \
    --seed=11  \
    --epochs=100  \
    --batch_size=100  \
    --base_optimizer=rmsprop  \
    --meta_optimizer=rmsprop \
    --num_meta_steps=0 \
    --lr=1e-4
```

**MNIST RMSprop-APO**
```
python train.py \
    --save_dir=mnist_apo \
    --dataset=mnist \
    --model=mlp \
    --seed=11 \
    --epochs=100 \
    --batch_size=100 \
    --base_optimizer=rmsprop \
    --meta_optimizer=rmsprop \
    --num_meta_steps=1 \
    --lr=1e-4 \
    --meta_lr=0.1 \
    --meta_interval=1 \
    --lam=1e-5
```

**Plot MNIST**
```
python plot_mnist_lr.py
```


### WideResNet-28-10 on CIFAR-10, using SGDm as the Base Optimizer

**Fixed LR Baseline**
```
python train.py \
    --save_dir=cifar10_wrn_sgdm_fixed_lr \
    --seed=11 \
    --base_optimizer=sgdmwd \
    --meta_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --num_meta_steps=0 \
    --data_augmentation \
    --model=wideresnet \
    --dataset=cifar10 \
    --lr=0.03 \
    --wdecay=0
```

**Decayed LR Baseline**
```
python train.py \
    --save_dir=cifar10_wrn_sgdm_decayed \
    --seed=11 \
    --base_optimizer=sgdmwd \
    --meta_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --num_meta_steps=0 \
    --data_augmentation \
    --model=wideresnet \
    --dataset=cifar10 \
    --lr=0.1 \
    --wdecay=5e-4 \
    --decay_at=60,120,160 \
    --factor=0.2 \
    --schedule
```

**APO**
```
python train.py \
    --save_dir=cifar10_wrn_sgdm_apo \
    --seed=11 \
    --base_optimizer=sgdmwd \
    --meta_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --data_augmentation \
    --model=wideresnet \
    --dataset=cifar10 \
    --lr=0.1 \
    --wdecay=1e-4 \
    --num_meta_steps=1 \
    --meta_interval=10 \
    --meta_lr=0.1 \
    --lam=0.1
```

**Plot CIFAR-10**
```
python plot_cifar10_lr.py
```


### ResNet34 on CIFAR-10, using RMSprop as the Base Optimizer

**Fixed LR Baseline**
```
python train.py \
    --save_dir=cifar10_resnet34_rmsprop_fixed_lr \
    --seed=11 \
    --base_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --num_meta_steps=0 \
    --data_augmentation \
    --model=resnet34 \
    --dataset=cifar10 \
    --lr=1e-3 \
    --wdecay=0
```

**Decayed LR Baseline**
```
python train.py \
    --save_dir=cifar10_resnet34_rmsprop_decayed_lr \
    --seed=11 \
    --base_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --num_meta_steps=0 \
    --data_augmentation \
    --model=resnet34 \
    --dataset=cifar10 \
    --lr=1e-3 \
    --wdecay=0 \
    --decay_at=60,120,160 \
    --factor=0.2 \
    --schedule
```

**APO**
```
python train.py \
    --save_dir=cifar10_resnet34_rmsprop_apo \
    --seed=11 \
    --base_optimizer=rmsprop \
    --batch_size=128 \
    --epochs=200 \
    --num_meta_steps=1 \
    --data_augmentation \
    --model=resnet34 \
    --dataset=cifar10 \
    --lr=1e-4 \
    --lam=1e-5 \
    --wdecay=0
```

**Plot CIFAR-10**
```
python plot_cifar10_lr.py
```
