from .base import Baseline
import numpy as np
import os
import time
from openbox.surrogate.base.build_gp import create_gp_model
from openbox.surrogate.base.rf_with_instances import RandomForestWithInstances
from openbox.acquisition_function.acquisition import EI
from openbox.utils.util_funcs import get_types
from openbox.utils.config_space.util import convert_configurations_to_array

from .acq_optimizer.local_random import InterleavedLocalAndRandomSearch


class BayesianOptimization(Baseline):
    """Bayesian-optimization baseline built on OpenBox components.

    A surrogate model (a GP via ``create_gp_model`` or a
    ``RandomForestWithInstances``) is fit to the observations collected so
    far. The EI acquisition function is maximized with
    ``InterleavedLocalAndRandomSearch`` to propose the next configuration.
    The first ``init_num`` configurations are drawn at random from the
    configuration space.

    NOTE(review): ``self.config_space``, ``self.observations``, ``self.rng``,
    ``self.seed``, ``self.save_dir`` and ``self.incumbent_value`` are
    presumably set up by ``Baseline``; confirm against the base class.
    """

    def __init__(self, config_space, eval_func, iter_num=200, save_dir='./results', task_name='default',
                 surrogate_type='prf', init_num=5, acq_num_points=5000):
        """Build the surrogate, acquisition function and acquisition optimizer.

        Args:
            config_space: Search space to optimize over.
            eval_func: Objective passed through to ``Baseline``.
            iter_num: Iteration budget passed to ``Baseline``; also embedded
                in the result file name.
            save_dir: Directory in which the result file path is built.
            task_name: Prefix of the result file name.
            surrogate_type: ``'gp'`` for a Gaussian process or ``'prf'`` for
                ``RandomForestWithInstances``.
            init_num: Number of random configurations evaluated before the
                surrogate is used.
            acq_num_points: ``num_points`` passed to the acquisition
                optimizer's ``maximize`` call.

        Raises:
            ValueError: If ``surrogate_type`` is not ``'gp'`` or ``'prf'``.
        """
        super().__init__(config_space, eval_func, iter_num, save_dir, task_name)
        types, bounds = get_types(config_space)

        if surrogate_type == 'gp':
            self.surrogate = create_gp_model(model_type='gp',
                                             config_space=config_space,
                                             types=types,
                                             bounds=bounds,
                                             rng=self.rng)
        elif surrogate_type == 'prf':
            self.surrogate = RandomForestWithInstances(types=types, bounds=bounds, seed=self.seed)
        else:
            raise ValueError("Surrogate type %s not supported!" % surrogate_type)

        self.acq_func = EI(self.surrogate)
        self.acq_optimizer = InterleavedLocalAndRandomSearch(acquisition_function=self.acq_func,
                                                             config_space=config_space, rng=self.rng)

        self.init_num = init_num
        self.acq_num_points = acq_num_points

        # The wall-clock timestamp makes repeated runs write to distinct files.
        self.timestamp = time.time()
        self.save_path = os.path.join(self.save_dir, '%s_%s_%d_%s.pkl' % (task_name, 'bo', iter_num, self.timestamp))

    def _is_evaluated(self, config):
        """Return True if ``config`` equals the configuration of any stored observation."""
        return any(config == observation[0] for observation in self.observations)

    def _sample_random_unseen(self, max_tries=1000):
        """Draw random configurations until one has not been evaluated yet.

        The number of draws is bounded by ``max_tries`` so that a small,
        nearly exhausted search space cannot cause an infinite loop. If no
        unseen configuration is found, the last drawn one is returned even
        though it is a duplicate.
        """
        config = None
        for _ in range(max(1, max_tries)):
            config = self.config_space.sample_configuration()
            if not self._is_evaluated(config):
                break
        return config

    def sample(self):
        """Propose the next configuration to evaluate.

        While fewer than ``init_num`` observations exist, a random unseen
        configuration is returned. Afterwards the surrogate is refit on all
        observations and EI is updated. The first challenger from the
        acquisition optimizer that has not been evaluated yet is returned.
        If every challenger was already evaluated, random sampling is used
        instead.
        """
        num_config_evaluated = len(self.observations)

        if num_config_evaluated < self.init_num:  # Sample initial configurations randomly
            return self._sample_random_unseen()

        # Each observation is indexed as (configuration, objective value, ...).
        X = convert_configurations_to_array([observation[0] for observation in self.observations])
        Y = np.array([observation[1] for observation in self.observations])

        self.surrogate.train(X, Y)

        self.acq_func.update(model=self.surrogate,
                             eta=self.incumbent_value,
                             num_data=num_config_evaluated)

        challengers = self.acq_optimizer.maximize(observations=self.observations,
                                                  num_points=self.acq_num_points)

        # Challengers are consumed in the order the optimizer returns them;
        # skip any that would only re-evaluate a known configuration.
        for config in challengers.challengers:
            if not self._is_evaluated(config):
                return config

        # Every proposed challenger was already evaluated. Fall back to random
        # search instead of indexing past the end of the challenger list.
        return self._sample_random_unseen()
