# @package _global_
# Hydra defaults list: each entry swaps the option selected for a config group.
# The leading `/` makes the group path absolute (resolved from the config root).
defaults:
  # Dataset config group -> `diving48`.
  - override /dataset: diving48
  # Optimizer preset; individual values are overridden under `model.optim_params`.
  - override /model/optim_params: optim_deit
  # Loss preset; individual values are overridden under `model.loss_params`.
  - override /model/loss_params: loss_defaults
# Overrides applied on top of the `diving48` dataset config selected in
# `defaults`.
dataset:
  # Data-loading settings.
  # NOTE(review): key names match torch.utils.data.DataLoader arguments and are
  # presumably forwarded to it -- confirm against the data module.
  dataloader_params:
    batch_size: 90
    num_workers: 8
    persistent_workers: true
  # Sequence-sampling settings.
  # NOTE(review): presumably `max_steps` caps the sampled sequence length and
  # `min_steps_past` / `min_steps_future` are lower bounds on the observed and
  # predicted parts -- confirm against the dataset class.
  dataset_params:
    max_steps: 10
    min_steps_past: 3
    min_steps_future: 3
    # NOTE(review): exact meaning of `uniform` is not visible here; the run name
    # (`..._uniftrain`) suggests a uniform-sampling training variant -- confirm.
    uniform: true
# Checkpointing overrides.
# NOTE(review): `save_top_k` matches the PyTorch Lightning ModelCheckpoint
# argument of the same name (keep the k best checkpoints) -- confirm it is
# forwarded there.
checkpoint:
  save_top_k: 3
# Trainer overrides.
# NOTE(review): key names match PyTorch Lightning Trainer arguments -- confirm
# they are passed through unchanged.
trainer:
  gpus: 4
  # Explicitly null, i.e. no plugin is set by this experiment.
  plugins: null
  # Separate from `model.optim_params.max_epoch`, which is documented there as
  # being counted in training steps.
  max_epochs: 10000
# Model definition for this experiment: a VRNN model with a Trajectron-style
# encoder/decoder, plus loss and optimizer overrides layered on the presets
# selected in `defaults`.
model:
  name: VRNNModel
  # sample_mode: true  # Only in an inference setting
  hidden_size: 512
  # The encoder/decoder sizes below are documented as 25*2, i.e. twice this value.
  feature_size: 25
  trajectron: true
  latent_distribution: 'categorical'
  encoder_params:
    name: EncoderTrajectron
    num_layers: 2
    in_size: 50  # Size of the points (25*2)
  decoder_params:
    name: TrajectronDecoder
    out_size: 50  # Size of the points (25*2)
    n_blocks: 4
  loss_params:  # InfoVAE formulation
    # One entry per loss term.
    # NOTE(review): `λ` is presumably the weight of the term and `params` its
    # extra keyword arguments -- confirm against the loss builder.
    dict_losses:
      kld_loss:
        λ: 1
        # Explicit null: no extra parameters are set for this term here.
        params: null
      reconstruction_loss:
        λ: 1
        params:
          distance_fn_name: 'euclidean_l2_keypoints'
          loss_type: 'regression'
      info_loss:
        λ: 1
        params:
          info_prior: true
  # Overrides on top of the `optim_deit` preset.
  optim_params:
    base_lr: 0.00001
    warmup_epochs: 1
    max_epoch: 500  # This is the number of training steps, not epochs
    warmup_start_lr: 0.000001
    grad_clip_val: 0.01
    grad_clip_strategy: 'norm'
# Run mode for this experiment.
# NOTE(review): the commented-out `sample_mode` option under `model` mentions an
# inference setting, so other values presumably exist -- confirm the accepted set.
setting: train
# Experiment-tracking overrides.
# NOTE(review): presumably `name` is used as the Weights & Biases run name --
# confirm against the logger setup.
wandb:
  name: diving_trajectron_uniftrain
