import numpy as np

from softlearning.samplers import SimpleSampler


class BBACSampler(SimpleSampler):
    """Sampler that re-selects its acting policy on every ``reset``.

    On each reset, ``self.policy`` is set either to the dedicated
    ``exploitation_policy`` (with probability
    ``exploitation_policy_sample_ratio``) or to one member of ``policies``
    picked uniformly at random.
    """

    def __init__(self,
                 policies,
                 exploitation_policy,
                 exploitation_policy_sample_ratio=0,
                 **kwargs):
        """Store the candidate policies and the exploitation settings.

        Args:
            policies: Indexable, sized collection of policies to choose from
                when the exploitation policy is not selected.
            exploitation_policy: Policy used when the exploitation branch is
                taken. It must not be one of the objects in ``policies``
                (checked by identity in ``reset``).
            exploitation_policy_sample_ratio: Probability of selecting
                ``exploitation_policy`` on a reset. With the default of 0
                the exploitation policy is never selected, because
                ``np.random.rand()`` draws from ``[0, 1)``.
            **kwargs: Forwarded unchanged to ``SimpleSampler.__init__``.
        """
        super().__init__(**kwargs)
        self.exploitation_policy = exploitation_policy
        self.policies = policies
        self.exploitation_policy_sample_ratio = exploitation_policy_sample_ratio

    def reset(self, *args, **kwargs):
        """Pick the policy for the next rollout, then reset the base sampler.

        Args:
            *args: Forwarded unchanged to ``SimpleSampler.reset``.
            **kwargs: Forwarded unchanged to ``SimpleSampler.reset``.

        Returns:
            Whatever ``SimpleSampler.reset`` returns.

        Raises:
            AssertionError: If the exploitation policy was selected but is
                the same object as one of ``self.policies``.
            ValueError: If a policy has to be drawn from ``self.policies``
                and that collection is empty.
        """
        if np.random.rand() < self.exploitation_policy_sample_ratio:
            self.policy = self.exploitation_policy
            # Raised explicitly rather than via `assert` so the invariant is
            # still enforced when Python runs with -O.
            if any(self.policy is x for x in self.policies):
                raise AssertionError(
                    "exploitation_policy must not be one of `policies`.")
        else:
            num_policies = len(self.policies)
            if not num_policies:
                raise ValueError(
                    "Cannot sample a policy: `policies` is empty.")
            # Draw an index instead of calling np.random.choice on the
            # collection: choice() first coerces its argument to an ndarray,
            # which is wasteful on every reset and breaks for policy objects
            # that are themselves sequence- or array-like.
            self.policy = self.policies[np.random.randint(num_policies)]

        return super().reset(*args, **kwargs)
