
import torch
from pyro.contrib.tracking.measurements import PositionMeasurement


def test_PositionMeasurement():
    """Smoke-test the attributes and output shapes of a 3-D PositionMeasurement.

    Builds a measurement with a random mean and an identity covariance. It then
    checks the stored dimension, mean, covariance, time and frame number. It
    also checks the shapes returned when the measurement is called on a state
    vector of length ``2 * dimension``, by ``geodesic_difference`` and by
    ``jacobian``.
    """
    dim = 3
    t = 0.232
    frame = 5
    # The state passed to the measurement is twice the measured dimension;
    # presumably position + velocity — TODO confirm against the class.
    state_dim = 2 * dim

    meas = PositionMeasurement(
        mean=torch.rand(dim),
        cov=torch.eye(dim),
        time=t,
        frame_num=frame,
    )
    assert meas.dimension == dim

    # Calling the measurement on a full state yields a dim-sized vector.
    state = torch.rand(state_dim)
    assert meas(state).shape == (dim,)

    # Distribution parameters keep their expected shapes.
    assert meas.mean.shape == (dim,)
    assert meas.cov.shape == (dim, dim)

    # Metadata passed to the constructor is exposed unchanged.
    assert meas.time == t
    assert meas.frame_num == frame

    lhs, rhs = torch.rand(dim), torch.rand(dim)
    assert meas.geodesic_difference(lhs, rhs).shape == (dim,)
    assert meas.jacobian().shape == (dim, state_dim)
