# Canonical Capsules: Self-Supervised Capsules in Canonical Pose

![teaser](docs/teaser.gif)

## Introduction

This is the official repository for the PyTorch implementation of "Canonical Capsules: Self-Supervised Capsules in Canonical Pose".

## Requirements

Install the dependencies into a new conda environment using the provided `environment.yml`:
```
conda env create -f environment.yml
```

## Datasets

- We use the ShapeNet data as prepared by AtlasNetV2: download it from AtlasNetV2's [official repo](https://github.com/TheoDEPRELLE/AtlasNetV2), then convert it into h5 files with the provided script, `data_utils/ShapeNetLoader.py`.

- For faster experimentation, please use our [2D planes dataset](https://drive.google.com/file/d/13FzfjGIBL7eypagy3-kAjG3GTIbzRFCU/view?usp=sharing), which we generated from ShapeNet.

## Training/testing (2D) 

To train the model on the 2D planes dataset (training takes only 50 epochs at roughly 2.5 minutes per epoch on an NVIDIA GTX 1080 Ti, i.e. about two hours in total):
```
./main.py --log_dir=plane_dim2 --indim=2 --scheduler=5
```

To visualize the decomposition and reconstruction produced by the trained checkpoint:
```
./main.py --save_dir=gifs_plane2d --indim=2 --scheduler=5 --mode=vis --pt_file=logs/plane_dim2/checkpoint.pth
``` 
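If you prefer to drive the 2D pipeline from Python, the following minimal sketch chains the two commands above. It uses only the flags shown in this README; the helper names are hypothetical, and the `logs/<log_dir>/checkpoint.pth` location is an assumption inferred from the `--pt_file` value in the visualization command.

```python
import subprocess
import sys


def checkpoint_path(log_dir):
    # Assumption: training writes logs/<log_dir>/checkpoint.pth, as implied by
    # the --pt_file value used in the visualization command.
    return f"logs/{log_dir}/checkpoint.pth"


def train_2d_args(log_dir="plane_dim2"):
    # Same flags as the 2D training command.
    return ["./main.py", f"--log_dir={log_dir}", "--indim=2", "--scheduler=5"]


def vis_2d_args(log_dir="plane_dim2", save_dir="gifs_plane2d"):
    # Visualization loads the checkpoint produced by the training run.
    return ["./main.py", f"--save_dir={save_dir}", "--indim=2", "--scheduler=5",
            "--mode=vis", f"--pt_file={checkpoint_path(log_dir)}"]


def run_2d_pipeline(log_dir="plane_dim2", save_dir="gifs_plane2d", dry_run=True):
    for args in (train_2d_args(log_dir), vis_2d_args(log_dir, save_dir)):
        print(" ".join(args))
        if not dry_run:
            # check=True raises CalledProcessError, so visualization is
            # skipped if training fails.
            subprocess.run(args, check=True)


if __name__ == "__main__":
    # Print the commands only, unless --run is given.
    run_2d_pipeline(dry_run="--run" not in sys.argv[1:])
```

By default the sketch only prints the two commands; passing `--run` executes them in order, stopping at the first step that fails.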

## Training/testing (3D)

To train the model on the 3D ShapeNet data (here on all categories, `--cat_id=-1`):
```
./main.py --log_dir=plane_dim3 --indim=3 --cat_id=-1
```

To test the trained model:
```
./main.py --log_dir=plane_dim3 --indim=3 --cat_id=-1 --mode=test
``` 

The `cat_id` option selects which category's h5 files are loaded, following [this look-up table](https://drive.google.com/file/d/1njXYVoc7uWW0vopuzX0rDF1vdr2c-Xow/view?usp=sharing):

| id | category |
|----|------------|
| -1 | all        |
| 0  | bench      |
| 1  | cabinet    |
| 2  | car        |
| 3  | cellphone  |
| 4  | chair      |
| 5  | couch      |
| 6  | firearm    |
| 7  | lamp       |
| 8  | monitor    |
| 9  | plane      |
| 10 | speaker    |
| 11 | table      |
| 12 | watercraft |
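For scripting over categories, the table can be mirrored as a Python dictionary. This is a minimal sketch: `CATEGORIES`, `cat_id_for`, and `eval_3d_args` are hypothetical helpers, not part of the repository, and the generated command reuses only the flags from the 3D test command above.

```python
# Mirror of the cat_id look-up table (-1 selects all categories).
CATEGORIES = {
    -1: "all", 0: "bench", 1: "cabinet", 2: "car", 3: "cellphone",
    4: "chair", 5: "couch", 6: "firearm", 7: "lamp", 8: "monitor",
    9: "plane", 10: "speaker", 11: "table", 12: "watercraft",
}
# Reverse mapping: category name -> cat_id.
CAT_IDS = {name: cat_id for cat_id, name in CATEGORIES.items()}


def cat_id_for(name):
    """Return the cat_id for a category name, e.g. "plane" -> 9."""
    if name not in CAT_IDS:
        raise ValueError(f"unknown category {name!r}; expected one of {sorted(CAT_IDS)}")
    return CAT_IDS[name]


def eval_3d_args(category, log_dir):
    """Build the 3D test command for one category (or "all")."""
    return ["./main.py", f"--log_dir={log_dir}", "--indim=3",
            f"--cat_id={cat_id_for(category)}", "--mode=test"]


if __name__ == "__main__":
    # The 13 single categories are the non-negative ids.
    for cat_id in sorted(i for i in CATEGORIES if i >= 0):
        print(cat_id, CATEGORIES[cat_id])
```

For example, `eval_3d_args("all", "plane_dim3")` reproduces the test command above, while `eval_3d_args("plane", ...)` sets `--cat_id=9`.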

## Pre-trained models (3D)
We release 3D [pretrained models](https://drive.google.com/file/d/1RblQQ-ocnrSOg4KTzFrLrU7SPOM_UwtU/view?usp=sharing)
for both the single-category setting (airplanes) and the multi-category setting (all 13 classes).


## Classification

To run our classification script on saved features (here with an SVM classifier, `--method_type=svm`):
```
python classification.py --data_dir=/path/to/saved/features --feature_type=caca --method_type=svm --use_kpts
```
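To compare runs with and without `--use_kpts`, a small wrapper like the one below may help. It is a sketch, not part of the repository: it assumes `--use_kpts` is a boolean switch (as its value-less use above suggests), keeps `caca` and `svm` as defaults because those are the only values documented here, and the script name `compare_kpts.py` in the usage comment is hypothetical.

```python
import subprocess
import sys


def classification_args(data_dir, use_kpts=True, feature_type="caca", method_type="svm"):
    # Defaults are the only values documented in this README.
    args = [sys.executable, "classification.py", f"--data_dir={data_dir}",
            f"--feature_type={feature_type}", f"--method_type={method_type}"]
    if use_kpts:
        # Assumption: --use_kpts is a flag enabled by its mere presence.
        args.append("--use_kpts")
    return args


def compare_kpts(data_dir, dry_run=True):
    # Run the classifier with and without --use_kpts on the same features.
    for use_kpts in (True, False):
        args = classification_args(data_dir, use_kpts=use_kpts)
        print(" ".join(args))
        if not dry_run:
            subprocess.run(args, check=True)


if __name__ == "__main__":
    # Usage: python compare_kpts.py /path/to/saved/features [--run]
    if len(sys.argv) > 1:
        compare_kpts(sys.argv[1], dry_run="--run" not in sys.argv[2:])
```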
