---
# Root of the experiment config; the top-level object has `<type>` Munch.
<type>: Munch
<init>: true

# Experiment logger (Weights & Biases).
logger:
  <type>: WandbLogger
  <init>: false
  # W&B project that receives all runs of this ablation.
  project: 'deq-neurips-prefix-sum-ablations'
  tags:
    - 'ablation9'
    - 'fixed_depth_ablation_1'
    - 'one_layer_classifier'

# PyTorch Lightning System.
system:
  <type>: SingleMixtureTaskLearner
  <init>: false

  # ---- Model variants. ----
  # All four variants share the same skeleton, no backward solver and a
  # single-layer classifier; they differ only in forward-solver depth
  # (32 vs 6 fixed-point iterations) and in whether `weight_normalization`
  # is set to false.
  # NOTE(review): `model|<name>` appears to declare named alternatives for the
  # `model` argument, and the `{}` key suffix appears to be loader-specific
  # syntax (possibly marking tracked/swept values) — confirm in the loader.
  model|unroll_bp:
    <type>: get_default_prefix_sum_skeleton
    <init>: true
    forward_solver{}:
      <type>: fixed_point_iterator
      <init>: false
      num_iters{}: 32
    backward_solver{}: null
    use_single_layer_classifier{}: true

  model|unroll_bp_no_wnorm:
    <type>: get_default_prefix_sum_skeleton
    <init>: true
    forward_solver{}:
      <type>: fixed_point_iterator
      <init>: false
      num_iters{}: 32
    backward_solver{}: null
    weight_normalization: false
    use_single_layer_classifier{}: true

  model|unroll_bp_shallow:
    <type>: get_default_prefix_sum_skeleton
    <init>: true
    forward_solver{}:
      <type>: fixed_point_iterator
      <init>: false
      num_iters{}: 6
    backward_solver{}: null
    use_single_layer_classifier{}: true

  model|unroll_bp_shallow_no_wnorm:
    <type>: get_default_prefix_sum_skeleton
    <init>: true
    forward_solver{}:
      <type>: fixed_point_iterator
      <init>: false
      num_iters{}: 6
    backward_solver{}: null
    weight_normalization: false
    use_single_layer_classifier{}: true

  # ---- Data loaders (one per partition). ----
  train_loader:
    <type>: get_default_loaders_for_prefix_sum
    <init>: true
    partition_name: "train"
    batch_size{}: 150

  valid_loader:
    <type>: get_default_loaders_for_prefix_sum
    <init>: true
    batch_size: 32
    partition_name: "valid"

  test_loader:
    <type>: get_default_loaders_for_prefix_sum
    <init>: true
    batch_size: 32
    partition_name: "test"

  # ---- Optimizer. ----
  # NOTE(review): the three `lr{}|INFER<n>` keys look like alternative
  # learning-rate settings chosen by the loader — confirm.
  optimizer{}:
    <type>: Adam
    <init>: false
    lr{}|INFER1: 0.001
    lr{}|INFER2: 0.0001
    lr{}|INFER3: 0.00001
    weight_decay: 0

  # ---- Learning-rate scheduler. ----
  # MultiStepLR multiplies the LR by `gamma` at each milestone.
  # NOTE(review): milestones are 50% / 75% of trainer.max_steps (30000) only if
  # the scheduler is stepped per optimizer step — confirm the step interval.
  lr_scheduler:
    <type>: MultiStepLR
    <init>: false
    milestones: [15000, 22500]
    gamma: 0.5

  # ---- Loss function. ----
  loss_fn:
    <type>: CrossEntropyLoss
    <init>: true

  # ---- Logging functions. ----
  logging_functions:
    - <type>: LogGenericTrainingState
      <init>: true
    - <type>: LogOptimizerStats
      <init>: true
    - <type>: LogClassificationMetrics
      <init>: true
    - <type>: LogMetricModelLogs
      <init>: true

  # ---- Other optional keyword arguments. ----
  model_kwargs_getter:
    <type>: get_pretraining_mode_kwarg
    <init>: false
    num_pretraining_steps: 0

# Trainer.
trainer:
  <type>: Trainer
  <init>: false
  # Training length is bounded by optimizer steps.
  max_steps: 30000
  profiler: false
  # Number of validation batches run as a sanity check before training.
  num_sanity_val_steps: 1
  check_val_every_n_epoch: 5000
  # Gradient-clipping threshold.
  gradient_clip_val: 1.0

# ____ Checkpoint callback. ____
checkpoint_callback:
  <type>: SimpleCheckpointer
  <init>: false
  # NOTE(review): presumably checkpoints land at integer steps roughly
  # multiplier * k**i (k ~ sqrt(2)) — confirm in int_powers_of_k.
  save_at_steps:
    <type>: int_powers_of_k
    <init>: true
    k: 1.415
    multiplier: 20
  keep_all_checkpoints: true
  # Metric used to pick the "best" checkpoint.
  best_checkpoint_determining_metric_key: 'validation/average_in_distribution_val_error'

# No additional callbacks are configured.
callbacks: null

# Task function.
task_fns:
  <type>: Munch
  <init>: true
  # Only a `train` task function is configured in this file.
  train:
    <type>: train
    <init>: false

# Other.
# Canonical lowercase booleans (yamllint `truthy`); values are unchanged.
resume_if_possible: true
test: false
seed{}|INFER1: 0
use_deterministic_algorithms: false
wandb_dryrun: false