#!/usr/bin/env python

import os
import shutil
import tempfile
import unittest
from copy import copy

from pyflann import *
from numpy import *
from numpy.random import *


class Test_PyFLANN_nn(unittest.TestCase):
    """Fixture for plain FLANN nearest-neighbour tests.

    NOTE(review): this class defines no test methods, so the fixture below
    is never exercised as written.
    """

    def setUp(self):
        """Attach a brand-new ``FLANN`` instance to the test as ``self.nn``."""
        flann_instance = FLANN()
        self.nn = flann_instance


class Test_PyFLANN_nn_index(unittest.TestCase):
    
    def testnn_index_save_kdtree_1(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kdtree", trees=1)

    def testnn_index_save_kdtree_4(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kdtree", trees=4)

    def testnn_index_save_kdtree_10(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kdtree", trees=10)


    def testnn_index_save_kmeans_2(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kmeans", branching=2, iterations=11)

    def testnn_index_save_kmeans_16(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kmeans", branching=16, iterations=11)
    
    def testnn_index_save_kmeans_32(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kmeans", branching=32, iterations=11)
    
    def testnn_index_save_kmeans_64(self):
        self.run_nn_index_save_perturbed(64,1000, algorithm="kmeans", branching=64, iterations=11)


    def testnn__save_kdtree_1(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kdtree", trees=1, checks=128)    

    def testnn__save_kdtree_4(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kdtree", trees=4, checks=128)

    def testnn__save_kdtree_10(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kdtree", trees=10, checks=128)

    def testnn__save_kmeans_2(self):
        self.run_nn_index_save_rand(64,1000,1000, algorithm="kmeans", branching=2, iterations=11, checks=64)

    def testnn__save_kmeans_8(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kmeans", branching=8, iterations=11, checks=32)

    def testnn__save_kmeans_16(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kmeans", branching=16, iterations=11, checks=40)
    
    def testnn__save_kmeans_32(self):
        self.run_nn_index_save_rand(64,10000,1000, algorithm="kmeans", branching=32, iterations=11, checks=56)




    def run_nn_index_save_perturbed(self, dim, N, **kwargs):
        
        x = rand(N, dim)

        nn = FLANN()
        nn.build_index(x, **kwargs)
        nn.save_index("index.dat")
        nn.delete_index();

        nn = FLANN()
        nn.load_index("index.dat",x)
        x_query = x + randn(x.shape[0], x.shape[1])*0.0001/dim
        nnidx, nndist = nn.nn_index(x_query)
        correct = all(nnidx == arange(N, dtype = index_type))
                
        nn.delete_index()
        self.assertTrue(correct)
    
    def run_nn_index_save_rand(self, dim, N, Nq, **kwargs):

        x = rand(N, dim)
        x_query = rand(Nq,dim)

        # build index, search and delete it
        nn = FLANN()
        nn.build_index(x, **kwargs)
        nnidx, nndist = nn.nn_index(x_query, checks=kwargs["checks"])
        nn.save_index("index.dat")
        del nn


        # now reload index and search again
        nn = FLANN()
        nn.load_index("index.dat",x)
        nnidx2, nndist2 = nn.nn_index(x_query, checks=kwargs["checks"])
        del nn

        correct = all(nnidx == nnidx2)
        self.assertTrue(correct)

# Allow running this module directly as a script; unittest.main() discovers
# and runs the TestCase classes defined above.
if __name__ == '__main__':
    unittest.main()
