"""
Evaluate P-IS of a batch of point clouds.

The point cloud batch should be saved to an .npz file containing an
``arr_0`` array of shape [N x K x 3], where N is the number of clouds and
K is the number of points in each cloud (each point has 3 coordinates).
"""

import argparse
import numpy as np
from evals.classifier.classifier import PointNetClassifier
from evals.fid_is import compute_inception_score


def _load_point_clouds(path):
    """Load the ``arr_0`` point-cloud batch from the .npz file at *path*.

    The stored array has shape [N x K x 3]; it is returned with its last two
    axes swapped, i.e. shape [N x 3 x K].

    Raises:
        KeyError: if the archive has no ``arr_0`` entry.
        ValueError: if ``arr_0`` is not a 3-D array.
    """
    # np.load on an .npz archive returns a lazy NpzFile mapping, not an
    # array; the `with` block makes sure the underlying file is closed.
    with np.load(path) as data:
        pts = data["arr_0"]
    if pts.ndim != 3:
        raise ValueError(
            f"expected arr_0 of shape [N x K x 3], got shape {pts.shape}"
        )
    # For numpy, ndarray.transpose(1, 2) is not an axis swap (a 3-D array
    # needs a full permutation); swapaxes performs the intended swap.
    return np.swapaxes(pts, 1, 2)


def main(argv=None):
    """Compute and print the P-IS of a batch of point clouds.

    Args:
        argv: optional list of command-line arguments to parse instead of
            ``sys.argv[1:]``. The single positional argument is the path to
            the .npz batch file.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate P-IS of a batch of point clouds."
    )
    parser.add_argument(
        "batch",
        type=str,
        help="path to an .npz file with an arr_0 array of shape [N x K x 3]",
    )
    args = parser.parse_args(argv)

    try:
        pts = _load_point_clouds(args.batch)
    except (KeyError, ValueError) as err:
        # Report malformed input as a CLI usage error instead of a traceback.
        parser.error(f"could not load point clouds from {args.batch}: {err}")

    print("creating classifier...")
    clf = PointNetClassifier()

    print("computing batch predictions")
    # Only the predictions feed the inception score; the first return value
    # (presumably per-cloud features) is discarded.
    # NOTE(review): assumes features_and_preds accepts a numpy array laid out
    # as [N x 3 x K] -- confirm against PointNetClassifier.
    _, preds = clf.features_and_preds(pts)
    print(f"P-IS: {compute_inception_score(preds)}")


if __name__ == "__main__":
    main()
