"""Tests for the Geometric Manifold Component Estimator (GEOMANCER)."""

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import geomancer


class GeomancerTest(parameterized.TestCase):
  """Unit tests for helper routines in the `geomancer` module."""

  @parameterized.parameters(
      {'zero_trace': False},
      {'zero_trace': True})
  def test_sym_op(self, zero_trace):
    """sym_op applied to tril(X) gives the same result as QXQ' for symmetric X.

    The test builds a random symmetric X and a random orthogonal Q. It checks
    that multiplying `geomancer.sym_op(q)` by the row-major lower-triangular
    vectorization of X (`np.tril_indices`), and then un-vectorizing with
    `geomancer.vec_to_sym`, reproduces Q X Q^T.

    With `zero_trace=True`, X is first shifted to have zero trace. The last
    lower-triangular entry, x[n-1, n-1], is then dropped from the vector.

    Args:
      zero_trace: whether to exercise the zero-trace variant of sym_op and
        vec_to_sym.
    """
    n = 5
    # Local seeded generator so any failure is reproducible.
    rng = np.random.RandomState(0)
    x = rng.randn(n, n)
    x += x.T  # Symmetrize.
    if zero_trace:
      # Subtract trace/n from every diagonal entry so that trace(x) == 0.
      np.fill_diagonal(x, np.diag(x) - np.trace(x) / n)
    q, _ = np.linalg.qr(rng.randn(n, n))  # Random orthogonal matrix.
    sym_q = geomancer.sym_op(q, zero_trace=zero_trace)
    tril_x = x[np.tril_indices(n)]
    if zero_trace:
      # np.tril_indices lists x[n-1, n-1] last. With zero trace it is fixed
      # by the other diagonal entries, so it is omitted from the vector.
      tril_x = tril_x[:-1]
    vec_y = sym_q @ tril_x
    y = q @ x @ q.T
    y_ = geomancer.vec_to_sym(vec_y, n, zero_trace=zero_trace)
    # Same tolerances as np.allclose's defaults, but with a useful diff on
    # failure instead of a bare AssertionError.
    np.testing.assert_allclose(y_, y, rtol=1e-5, atol=1e-8)

  def test_ffdiag(self):
    # TODO: add a real test for geomancer's ffdiag routine. Skip rather than
    # `pass` so the missing coverage is visible in test reports instead of
    # showing up as a silent pass.
    self.skipTest('test_ffdiag is not implemented yet.')

  def test_make_nearest_neighbor_graph(self):
    """On an evenly sampled circle, each point's 4 neighbors are i±1, i±2."""
    n = 100
    # n points evenly spaced on the unit circle: column 0 = sin, 1 = cos.
    theta = np.arange(n) * 2 * np.pi / n
    data = np.stack([np.sin(theta), np.cos(theta)], axis=1)
    graph = geomancer.make_nearest_neighbors_graph(data, 4, n=10)
    for i in range(n):
      # `graph.rows[i]` is read as the neighbor list of point i. The
      # attribute name suggests a scipy.sparse lil_matrix -- TODO confirm.
      neighbors = graph.rows[i]
      self.assertLen(neighbors, 4)
      for offset in (1, 2, -1, -2):
        self.assertIn((i + offset) % n, neighbors)


if __name__ == '__main__':
  # When run as a script, hand control to absltest's runner for the test
  # cases defined in this module.
  absltest.main()
