# Hydra defaults list. Entries are composed in order, and later entries
# override earlier ones. Because `_self_` is listed first, values written in
# this file are overridden by the selected dataset / pretrained_model /
# classifier group configs wherever those set the same keys.
defaults:
  - _self_
  - dataset: mnist
  - pretrained_model: vitca
  - classifier: linear
  # Replace Hydra's own job/launcher logging configs with the `colorlog`
  # variants (provided by the hydra-colorlog plugin, which must be installed).
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog

# Experiment-wide settings. These keys are read by training/evaluation code
# outside this file; `name` and `save_root` are also used to build
# hydra.run.dir at the bottom of this file.
experiment:
  name: mnist
  device: cuda:0
  # Placeholder: replace with the root directory for run outputs.
  save_root: SAVE_ROOT
  num_workers: 2
  # Placeholder path: point this at the pretrained checkpoint to load.
  pretrained_model_path: 'FOLDER/TO/nca_best.pth.tar'  # ViTCA baseline
  deterministic: true

  # NOTE(review): presumably toggles feeding averaged features to the
  # classifier — confirm in the consuming code.
  use_avg_feats: false

  trainer:
    loss:
      # `_target_` is the dotted path of the class Hydra instantiates.
      _target_: image_classification.src.losses.CALoss
      # _target_: image_classification.src.losses.Loss  # for non-CA models
      # Values spelled like `1e2` / `1e-3` (no decimal point) load as floats
      # through OmegaConf's YAML loader, but a plain YAML 1.1 loader (e.g.
      # PyYAML's safe_load) reads them as strings.
      # NOTE(review): presumably weights on the reconstruction and overflow
      # loss terms — confirm in CALoss.
      rec_factor: 1e2
      overflow_factor: 1e2
    checkpointing:
      enabled: false
    linear_probing:
      # Use a batch size of 32 (8 for tinyimagenet, since its images are 64x64).
      enabled: true
      masking_enabled: false
      use_pretrained_model: true
      # Disabled alternative: SGD with a cosine-annealed lr, momentum 0.9,
      # lr 1e-1 and weight decay 1e-4, annealed over iter.train.total steps.
      # opt:
      #   _target_: torch.optim.SGD
      #   lr: 1e-1
      #   momentum: 0.9
      #   weight_decay: 1e-4
      # lr_sched:
      #   _target_: torch.optim.lr_scheduler.CosineAnnealingLR
      #   T_max: ${experiment.iter.train.total}
      # Active choice: AdamW with a constant lr of 1e-3 (no lr schedule),
      # the same optimizer settings used for fine_tuning.
      opt:
        _target_: torch.optim.AdamW
        lr: 1e-3
      lr_sched: null
    fine_tuning:
      # AdamW with a constant lr of 1e-3 (no lr schedule), the same optimizer
      # settings used for linear_probing.
      # Fine-tune with cross-entropy for classification + L1 for (denoising,
      # if requested) autoencoding.
      # Use a batch size of 32 (8 for tinyimagenet, since its images are 64x64).
      enabled: false
      masking_enabled: false
      opt:
        _target_: torch.optim.AdamW
        lr: 1e-3
      lr_sched: null
    # TODO: better integrate few-shot learning into the source code.
    fewshot:
      enabled: false
      # NOTE(review): presumably the shots per class and the number of
      # episodes — confirm against the few-shot code path.
      samples_per_class: 1
      episode_length: 3

  iter:
    train:
      # `total` is also used as T_max by the disabled SGD/cosine alternative
      # above. NOTE(review): `start` is presumably the first iteration index.
      start: 1
      total: 100000
      # NOTE(review): presumably the [min, max] range of CA steps per training
      # iteration and the CA update rate — confirm in the model code.
      ca:
        min: 8
        max: 32
        update_rate: 0.5
    val:
      # Validation uses a single fixed `value` instead of the train min/max range.
      ca:
        value: 64
        update_rate: 0.5
    test:  # for linear probing with CA-based models
      ca:
        value: 64
        update_rate: 0.5

  # NOTE(review): unlike iter/input_size/attn_size/masking there is no `test`
  # entry here; confirm which batch size the test split uses.
  batch_size:
    train: 32
    val: 32

  pool_size: 1024
  resume_from_latest: true

  sample_with_replacement: false

  # Per-split input sizes; presumably [height, width] — confirm.
  input_size:
    train: [32, 32]
    val: [32, 32]
    test: [32, 32]  # for linear probing with CA-based models

  # Per-split attention window sizes; presumably [height, width] — confirm.
  attn_size:
    train: [3, 3]
    val: [3, 3]
    test: [3, 3]  # for linear probing with CA-based models

  # Logging / checkpoint / validation frequencies, presumably in iterations.
  # NOTE(review): log_frequency 11 is unusual next to the round 5000/1000
  # values — confirm it is intentional.
  log_frequency: 11
  save_frequency: 5000
  val_frequency: 1000

  normalize_gradients: true

  random_seed: 1

  # Input-masking settings per split. Only `train` has schedule_start/end;
  # NOTE(review): presumably the iteration window over which masking ramps
  # up to max_prob / max_patch_shape — confirm in the masking code.
  masking:
    train:
      type: noise
      max_prob: 0.75
      max_patch_shape: [4, 4]
      prob_stages: 3
      patch_shape_stages: 3
      schedule_start: 500
      schedule_end: 10000
    val:
      type: noise
      max_prob: 0.75
      max_patch_shape: [4, 4]
      prob_stages: 3
      patch_shape_stages: 3
    test:  # for linear probing with CA-based models
      type: noise
      max_prob: 0.75
      max_patch_shape: [4, 4]
      prob_stages: 3
      patch_shape_stages: 3

# Classifier-head settings. Unless classifier/linear.yaml declares a different
# package, it is composed into this same `classifier` node. Because `_self_`
# comes first in the defaults list, a num_classes set there would win.
classifier: {num_classes: 10}  # MNIST has 10 digit classes (dataset: mnist)

# Weights & Biases run settings. ENTITY and PROJECT look like placeholders to
# replace with a real W&B entity/project. NOTE(review): run_id is null here;
# presumably that starts a fresh run — confirm in the training script.
wandb: {entity: ENTITY, project: PROJECT, run_id: null}

hydra:
  # Single-run output directory, laid out as
  # <experiment.save_root>/<experiment.name>/<YYYY-MM-DD>/<HH-MM-SS>,
  # with the date/time parts formatted by Hydra's `now` resolver.
  run: {dir: '${experiment.save_root}/${experiment.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}'}