from .base import Baseline
import numpy as np
import os
import time
import pickle as pkl
from openbox.surrogate.base.build_gp import create_gp_model
from openbox.surrogate.base.rf_with_instances import RandomForestWithInstances
from openbox.acquisition_function.acquisition import EI
from openbox.utils.util_funcs import get_types
from openbox.utils.config_space.util import convert_configurations_to_array

from .acq_optimizer.local_random import InterleavedLocalAndRandomSearch

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics._scorer import _BaseScorer, _PredictScorer, _ThresholdScorer

from utils.time import time_limit, TimeoutException

from mindware.components.utils.constants import *

class BayesianOptimizationEnsemble(Baseline):
    """Bayesian optimization whose objective is the score of a bagged ensemble.

    Every successfully evaluated configuration is appended to a bounded,
    first-in-first-out ensemble.  The value reported for a trial (and the
    target the surrogate is trained on) is the negated scorer value of the
    averaged predictions of the ensemble that includes that configuration,
    rather than the score of the single model.

    ``eval_func(config)`` must return
    ``(val_obj, test_obj, val_pred, test_pred)``.

    NOTE(review): entries of ``self.observations`` are indexed as
    ``observation[0]`` (config), ``observation[1]`` (validation objective)
    and ``observation[3]`` (validation predictions); this presumably mirrors
    the argument order of ``self.update`` in ``Baseline`` -- confirm there.
    """

    def __init__(self, config_space, eval_func, iter_num=200, save_dir='./results', task_name='default',
                 surrogate_type='prf', scorer=None, task_type='cls', train_node=None, test_node=None):
        """Build the surrogate, acquisition machinery and the label caches.

        Args:
            config_space: Search space to sample configurations from.
            eval_func: Callable evaluating one configuration (see class doc).
            iter_num: Number of trials executed by ``run``.
            save_dir: Directory the pickled observations are written to.
            task_name: Used as a prefix of the result file name.
            surrogate_type: ``'gp'`` or ``'prf'``.
            scorer: Object exposing ``_score_func(y_true, y_pred)``; the
                objective is the negated value of that function.
            task_type: ``'cls'`` selects a stratified split and arg-maxes the
                averaged predictions; any other value uses a plain shuffle
                split and the averaged predictions directly.
            train_node: Object whose ``data`` is ``(X, y)`` for training.
            test_node: Object whose ``data[1]`` holds the test labels.

        Raises:
            ValueError: If ``surrogate_type`` is unsupported, or if
                ``train_node`` / ``test_node`` is missing.
        """
        # Fail with a clear message instead of an AttributeError on None.data.
        if train_node is None or test_node is None:
            raise ValueError("train_node and test_node are required to recover validation/test labels!")

        super().__init__(config_space, eval_func, iter_num, save_dir, task_name)
        types, bounds = get_types(config_space)

        if surrogate_type == 'gp':
            self.surrogate = create_gp_model(model_type='gp',
                                             config_space=config_space,
                                             types=types,
                                             bounds=bounds,
                                             rng=self.rng)
        elif surrogate_type == 'prf':
            self.surrogate = RandomForestWithInstances(types=types, bounds=bounds, seed=self.seed)
        else:
            raise ValueError("Surrogate type %s not supported!" % surrogate_type)

        self.acq_func = EI(self.surrogate)
        self.acq_optimizer = InterleavedLocalAndRandomSearch(acquisition_function=self.acq_func,
                                                             config_space=config_space, rng=self.rng)

        # Number of purely random trials before the surrogate is used.
        self.init_num = 5

        self.timestamp = time.time()
        self.save_path = os.path.join(self.save_dir, '%s_%s_%d_%s.pkl' % (task_name, 'eo', iter_num, self.timestamp))

        self.task_type = task_type
        self.scorer = scorer

        # The ensemble is a FIFO window over the most recent successful models.
        self.ensemble_size = 12
        self.ensemble = list()

        # Store valid labels and test labels
        test_size = 0.25  # Consistent with evaluate.py
        seed = 1  # Consistent with evaluate.py
        if task_type == 'cls':
            from sklearn.model_selection import StratifiedShuffleSplit

            ss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        else:
            from sklearn.model_selection import ShuffleSplit

            ss = ShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        # n_splits=1, so the splitter yields exactly one (train, validation) pair.
        _, val_index = next(ss.split(train_node.data[0], train_node.data[1]))
        self._y_val = train_node.data[1][val_index]
        self._y_test = test_node.data[1]

    def _is_evaluated(self, config):
        """Return True if ``config`` equals the config of any observation."""
        return any(config == observation[0] for observation in self.observations)

    def _sample_unseen_random(self):
        """Draw random configurations until one has not been evaluated yet."""
        while True:
            config = self.config_space.sample_configuration()
            if not self._is_evaluated(config):
                return config

    def _score_members(self, preds, y_true):
        """Return the negated score of the averaged predictions in ``preds``."""
        averaged = np.mean(np.array(preds), axis=0)
        if self.task_type == 'cls':
            averaged = np.argmax(averaged, axis=-1)
        return -self.scorer._score_func(y_true, averaged)

    def _failed_trial(self, iter_id, start_time):
        """Log a failed trial and return its placeholder result tuple."""
        val_obj, test_obj = np.inf, np.inf
        runtime = time.time() - start_time
        print('Iter: %d, Failed Obj: %f, Test obj: %f, Eval time: %f' % (iter_id, val_obj, test_obj, runtime))
        return val_obj, test_obj, None, None, runtime

    def sample(self):
        """Suggest the next configuration to evaluate.

        The first ``init_num`` suggestions are random.  Afterwards every past
        observation is re-scored as if it joined the current ensemble, the
        surrogate is refit on those scores and the best not-yet-evaluated
        challenger of the acquisition optimizer is returned.

        Returns:
            A configuration that does not equal any observed configuration.
        """
        num_config_evaluated = len(self.observations)

        if num_config_evaluated < self.init_num:  # Sample initial configurations randomly
            return self._sample_unseen_random()

        # Make room for the model that the upcoming trial may add.
        if len(self.ensemble) == self.ensemble_size:
            del self.ensemble[0]

        X = convert_configurations_to_array([observation[0] for observation in self.observations])

        val_pred_ens = [model[3] for model in self.ensemble]
        # Failed trials (objective == inf) keep an infinite target.
        # NOTE(review): infinite targets are passed to the surrogate unchanged,
        # exactly as before; confirm the surrogate tolerates them.
        Y = np.array([
            np.inf if observation[1] == np.inf
            else self._score_members(val_pred_ens + [observation[3]], self._y_val)
            for observation in self.observations
        ])

        self.surrogate.train(X, Y)

        # Previously written to a misspelled attribute, so EI never saw it.
        self.incumbent_value = np.min(Y)
        self.acq_func.update(model=self.surrogate,
                             eta=self.incumbent_value,
                             num_data=num_config_evaluated)

        challengers = self.acq_optimizer.maximize(observations=self.observations,
                                                  num_points=5000)

        # Walk the ranked challengers until an unseen one turns up; if the
        # list is exhausted fall back to random sampling instead of crashing.
        rank = 0
        while True:
            try:
                cur_config = challengers.challengers[rank]
            except IndexError:
                return self._sample_unseen_random()
            if not self._is_evaluated(cur_config):
                return cur_config
            rank += 1

    def run(self, time_limit_per_trial=30):
        """Run ``iter_num`` trials and pickle the observations.

        Args:
            time_limit_per_trial: Budget handed to ``time_limit`` for a single
                call of ``eval_func``.

        A trial that times out or raises is recorded with infinite objectives
        and ``None`` predictions and does not join the ensemble.
        """
        for iter_id in range(self.iter_num):
            config = self.sample()
            start_time = time.time()
            try:
                with time_limit(time_limit_per_trial):
                    val_obj, test_obj, val_pred, test_pred = self.eval_func(config)
                runtime = time.time() - start_time

                member = (config, val_obj, test_obj, val_pred, test_pred, runtime)
                members = self.ensemble + [member]
                val_obj_ens = self._score_members([model[3] for model in members], self._y_val)
                test_obj_ens = self._score_members([model[4] for model in members], self._y_test)
                # Only a model whose ensemble could be scored becomes a member,
                # so a broken prediction cannot poison later iterations.
                self.ensemble.append(member)

                print('Iter: %d, Obj: %f, Test obj: %f, Eval time: %f' % (iter_id, val_obj_ens, test_obj_ens, runtime))
                # The trial is recorded with the ensemble's objectives.
                val_obj = val_obj_ens
                test_obj = test_obj_ens

            except TimeoutException:
                print('Time out!')
                val_obj, test_obj, val_pred, test_pred, runtime = self._failed_trial(iter_id, start_time)
            except Exception as e:
                # Best effort: one failing configuration must not stop the run.
                print(e)
                val_obj, test_obj, val_pred, test_pred, runtime = self._failed_trial(iter_id, start_time)
            self.update(config, val_obj, test_obj, val_pred, test_pred, runtime)

        # Ensure the target directory exists so finished results are not lost.
        target_dir = os.path.dirname(self.save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(self.save_path, 'wb') as f:
            pkl.dump(self.observations, f)