# PreDiff
This repository is the official PyTorch implementation of PreDiff, submitted to NeurIPS 2023.
## Installation
We recommend managing the environment through Anaconda. First, create a new conda environment:
```bash
conda create -n prediff python=3.9.15
conda activate prediff
```
Then, install PyTorch built for your CUDA version (the command below targets CUDA 11.7):
```bash
python -m pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
```
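To confirm that the CUDA-enabled build was picked up, a quick check along the following lines can help. This is a minimal sketch: the function name `check_torch_cuda` is hypothetical and not part of this repo; it only relies on the standard `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()` attributes.

```bash
# Hypothetical helper (not part of this repo): prints the installed PyTorch
# version, the CUDA version it was built against, and whether a GPU is visible.
# Returns non-zero if torch cannot be imported in the current environment.
check_torch_cuda() {
    python -c 'import torch; print("torch", torch.__version__); print("built for CUDA", torch.version.cuda); print("CUDA available:", torch.cuda.is_available())'
}

# Call it inside the activated prediff environment:
# check_torch_cuda
```

With the install command above, the reported versions should correspond to `1.13.1+cu117` and CUDA `11.7`. If `CUDA available` is `False`, the GPU driver or the selected wheel is the likely culprit.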
Finally, install PreDiff in development (editable) mode:
```bash
cd ROOT_DIR/prediff
python -m pip install -U -e . --no-build-isolation
```

## Dataset
Run the following command to generate the N-body MNIST dataset:
```bash
cd ROOT_DIR/prediff
python ./scripts/nbody_dataset/generate_nbody_dataset.py --cfg ./scripts/nbody_dataset/cfg.yaml
```
See the corresponding [README](./scripts/nbody_dataset/README.md) for more information.

## Training
Run the following command to first train the VAE on the N-body MNIST dataset:
```bash
cd ROOT_DIR/prediff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/vae/train_vae_nbody.py --gpus 2 --cfg ./scripts/vae/cfg.yaml --save tmp_vae_nbody
```
After the VAE is trained, copy its checkpoint from `ROOT_DIR/prediff/experiments/tmp_vae_nbody/checkpoints/vae_nbody.pt` to `ROOT_DIR/prediff/pretrained/vae_nbody.pt`.
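
The copy step can be scripted, for example as below. This is only a sketch: `stage_checkpoint` is a hypothetical helper, not a script shipped with this repo, and `ROOT_DIR` is assumed to be set as a shell variable pointing at the directory that contains `prediff`.

```bash
# Hypothetical helper (not part of this repo): copies a trained checkpoint
# into place, creating the destination directory if it does not exist yet.
#   $1: checkpoint written by the training script
#   $2: destination path under pretrained/
stage_checkpoint() {
    if [ ! -f "$1" ]; then
        echo "checkpoint not found: $1" >&2
        return 1
    fi
    mkdir -p "$(dirname "$2")"
    cp "$1" "$2"
}

# Example, assuming ROOT_DIR is set in the shell:
# stage_checkpoint "$ROOT_DIR/prediff/experiments/tmp_vae_nbody/checkpoints/vae_nbody.pt" \
#                  "$ROOT_DIR/prediff/pretrained/vae_nbody.pt"
```

Copying rather than moving keeps the original file in the experiment directory, so the run stays intact if the destination is overwritten later.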

Then, run the following command to train the knowledge control network on the N-body MNIST dataset:
```bash
cd ROOT_DIR/prediff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/knowledge_control/train_kc_nbody.py --gpus 2 --cfg ./scripts/knowledge_control/cfg.yaml --save tmp_kc_nbody
```
After the knowledge control network is trained, copy its checkpoint from `ROOT_DIR/prediff/experiments/tmp_kc_nbody/checkpoints/kc_nbody.pt` to `ROOT_DIR/prediff/pretrained/kc_nbody.pt`.

The latent diffusion model and the knowledge control network can be trained in either order.
To train the latent diffusion model of PreDiff on the N-body MNIST dataset, run:
```bash
cd ROOT_DIR/prediff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/ldm/train_ldm_nbody.py --gpus 2 --cfg ./scripts/ldm/cfg.yaml --save tmp_ldm_nbody
```
With the trained PreDiff model saved at `ROOT_DIR/prediff/experiments/tmp_ldm_nbody`, run the following command to perform inference without knowledge control:
```bash
cd ROOT_DIR/prediff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/ldm/train_ldm_nbody.py --gpus 2 --cfg ./scripts/ldm/cfg.yaml --save tmp_ldm_nbody --test --ckpt_name last.ckpt
```
To perform inference under knowledge control, use `cfg_kc.yaml` instead:
```bash
cd ROOT_DIR/prediff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/ldm/train_ldm_nbody.py --gpus 2 --cfg ./scripts/ldm/cfg_kc.yaml --save tmp_ldm_nbody --test --ckpt_name last.ckpt
```
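To compare the two inference modes back to back, the two test commands can be wrapped in a small loop. The following is a sketch under assumptions: `run_inference_both`, the `DRY_RUN` switch, and the log file names are hypothetical and not part of this repo; only the commands themselves are taken from this README. By default the helper just prints the commands so they can be inspected first.

```bash
# Hypothetical helper (not part of this repo). Run it from ROOT_DIR/prediff.
# With DRY_RUN=1 (the default) it prints the two test commands from this README;
# with DRY_RUN=0 it executes them in sequence and logs each run to its own file.
run_inference_both() {
    for cfg in cfg.yaml cfg_kc.yaml; do
        # Build the command for this config as the positional parameters.
        set -- python ./scripts/ldm/train_ldm_nbody.py --gpus 2 \
            --cfg "./scripts/ldm/$cfg" --save tmp_ldm_nbody --test --ckpt_name last.ckpt
        if [ "${DRY_RUN:-1}" = "1" ]; then
            echo "MASTER_ADDR=localhost MASTER_PORT=10001 $*"
        else
            # Note: the pipeline's exit status is that of tee, not of python.
            MASTER_ADDR=localhost MASTER_PORT=10001 "$@" 2>&1 | tee "test_${cfg%.yaml}.log"
        fi
    done
}

# Print the commands:      run_inference_both
# Actually run them:       DRY_RUN=0 run_inference_both
```

Running the two modes sequentially rather than in parallel avoids two processes competing for the same `MASTER_PORT` and GPUs.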
More information can be found in the corresponding [README](./scripts/vae/README.md) for VAE, [README](./scripts/ldm/README.md) for the latent diffusion model, and [README](./scripts/knowledge_control/README.md) for knowledge control.
