import unittest

import torch

from bbo.algorithms import (
    RandomStrategy,
    BestKStrategy,
)


class RandomStrategyTest(unittest.TestCase):
    """Unit tests for ``RandomStrategy``."""

    def test_fill(self):
        """``fill`` yields a 1-D vector of length ``dim`` inside ``[lb, ub]``."""
        n_dims = 5
        lower = torch.zeros(n_dims)
        upper = torch.tensor([1, 2, 3, 2, 10])
        strategy = RandomStrategy(n_dims, lower, upper)

        # Explicit values for coordinates 2 and 3; both lie within their
        # respective bounds ([0, 3] and [0, 2]).
        fixed_idx = [2, 3]
        fixed_vals = torch.tensor([0.5, 1])

        candidate = strategy.fill(fixed_idx, fixed_vals)

        self.assertEqual(candidate.shape, (n_dims, ))
        # Check each bound separately so a failure points at the violated side.
        self.assertTrue((candidate >= lower).all())
        self.assertTrue((candidate <= upper).all())

    
class BestKStrategyTest(unittest.TestCase):
    """Unit tests for ``BestKStrategy``."""

    def test_fill(self):
        """Seed ``k`` observations, ``fill`` a point, then feed it back via ``update``."""
        n_dims = 5
        lower = torch.zeros(n_dims)
        upper = torch.tensor([1, 2, 3, 2, 1])
        k = 3
        strategy = BestKStrategy(n_dims, lower, upper, k=k)

        # Seed the strategy with k uniformly random points in [lb, ub],
        # paired with objective values 1, 2 and 3.
        seed_X = lower + (upper - lower) * torch.rand((k, n_dims))
        seed_Y = torch.tensor([1, 2, 3], dtype=torch.float)
        for x_row, y_val in zip(seed_X, seed_Y):
            strategy.update(x_row, y_val)

        # Request a candidate with coordinates 2 and 3 pinned to explicit values.
        fixed_idx = [2, 3]
        fixed_vals = torch.tensor([0.5, 1])

        candidate = strategy.fill(fixed_idx, fixed_vals)
        self.assertEqual(candidate.shape, (n_dims, ))
        self.assertTrue((candidate >= lower).all())
        self.assertTrue((candidate <= upper).all())

        # Report the candidate with objective 2. Afterwards the stored best
        # values must equal [2, 2, 3] as a column vector. NOTE(review): this
        # implies the value 1 is evicted (larger is better?) — confirm
        # against BestKStrategy.update.
        strategy.update(candidate, 2)
        expected_best = torch.tensor([2, 2, 3]).reshape(-1, 1)
        self.assertTrue((strategy.best_Y == expected_best).all())