from .base import Baseline
import numpy as np
import random
import os
import time
from openbox.utils.config_space import get_one_exchange_neighbourhood
from ensembles.ensemble_selection import EnsembleSelection


class RegularizedEAEnsemble(Baseline):
    """Regularized evolution whose parents are taken from an intermediate ensemble.

    The first ``population_size`` configurations are sampled at random. After
    that, every new candidate is a one-exchange neighbour of a configuration
    that is a member of an ``EnsembleSelection`` ensemble fitted on the
    validation predictions of the current population. The population is aged
    FIFO-style: when it grows beyond ``population_size`` the oldest member is
    dropped.
    """

    def __init__(self, config_space, eval_func, iter_num=200, save_dir='./results', task_name='default',
                 ens_size=25, scorer=None, task_type='cls', val_y_labels=None,
                 population_size=20):
        """Set up the search state and the intermediate-ensemble bookkeeping.

        :param config_space: search space; ``sample_configuration()`` is called on it.
        :param eval_func: evaluation function, forwarded to ``Baseline``.
        :param iter_num: evaluation budget, forwarded to ``Baseline``; also part of
            the result file name.
        :param save_dir: result directory, forwarded to ``Baseline``.
        :param task_name: prefix of the result file name.
        :param ens_size: ``ensemble_size`` passed to ``EnsembleSelection``.
        :param scorer: scorer passed to ``EnsembleSelection``; also used to print the
            ensemble's validation score (presumably a scikit-learn scorer, since its
            ``_score_func`` is used -- confirm).
        :param task_type: task type passed to ``EnsembleSelection``.
        :param val_y_labels: validation targets the ensemble is fitted against. Required.
        :param population_size: number of random initial configurations and the
            maximum population size.
        :raises ValueError: if ``val_y_labels`` is None.
        """
        super().__init__(config_space, eval_func, iter_num, save_dir, task_name)

        self.timestamp = time.time()
        self.save_path = os.path.join(self.save_dir, '%s_%s_%d_%s.pkl' % (task_name, 'reaes', iter_num, self.timestamp))

        # Intermediate ensemble. Raise rather than assert so the check survives `python -O`.
        if val_y_labels is None:
            raise ValueError('val_y_labels is required to fit the intermediate ensemble.')
        self.val_y_labels = val_y_labels
        self.ens_size = ens_size
        self.scorer = scorer
        self.task_type = task_type
        self.ensemble = None        # EnsembleSelection, rebuilt on every update()
        self.e_config_list = []     # population configs that have a validation prediction
        self.e_valid_list = []      # their validation predictions, index-aligned with e_config_list
        self.cmp_config_list = []   # configs selected into the ensemble (one per model_idx entry)

        self.population_size = population_size
        self.population = list()    # observation tuples, oldest first

    def _is_evaluated(self, config):
        """Return True if ``config`` equals the configuration of any recorded observation."""
        return any(config == observation[0] for observation in self.observations)

    def _sample_random_unevaluated(self):
        """Draw random configurations until one that has not been evaluated yet comes up.

        NOTE(review): never returns if every configuration of a finite space has
        already been evaluated.
        """
        while True:
            config = self.config_space.sample_configuration()
            if not self._is_evaluated(config):
                return config

    def sample(self):
        """Propose the next configuration to evaluate.

        While fewer than ``population_size`` observations exist, a random
        unevaluated configuration is returned. Otherwise a parent is drawn
        uniformly from the ensemble members (``cmp_config_list``; a configuration
        selected several times into the ensemble is proportionally more likely),
        and the first unevaluated one-exchange neighbour of it is returned. If all
        generated neighbours were already evaluated, a random unevaluated
        configuration is returned instead.
        """
        if len(self.observations) < self.population_size:  # Sample initial configurations randomly
            return self._sample_random_unevaluated()

        # ensemble.model_idx indexes e_config_list (only population entries that
        # have a validation prediction), NOT self.population, so parents must be
        # taken from cmp_config_list. Fall back to the whole population when no
        # ensemble could be fitted (e.g. no member has a validation prediction).
        candidates = self.cmp_config_list or [ob[0] for ob in self.population]
        if not candidates:
            return self._sample_random_unevaluated()
        # Use the searcher's own RNG (not the global `random` module) so the parent
        # choice is reproducible together with the neighbourhood seed below.
        parent = candidates[self.rng.randint(len(candidates))]

        # Mutation to 1-step neighbors; return the first one not evaluated yet.
        neighbors_gen = get_one_exchange_neighbourhood(parent, seed=self.rng.randint(int(1e8)), num_neighbors=50)
        for neighbor in neighbors_gen:
            if not self._is_evaluated(neighbor):
                return neighbor

        # If all the neighbors are evaluated, sample randomly!
        print('All neighbors are evaluated! Choose a random one instead!')
        return self._sample_random_unevaluated()

    def update(self, config, val_perf, test_perf, val_pred, test_pred, time):
        """Record an evaluation, age the population and refit the intermediate ensemble.

        :param config: the evaluated configuration.
        :param val_perf: validation performance; lower is better (it replaces the
            incumbent when smaller than ``incumbent_value``).
        :param test_perf: test performance, stored as-is.
        :param val_pred: validation predictions; members whose prediction is None
            are left out of the ensemble.
        :param test_pred: test predictions, stored as-is.
        :param time: evaluation time, stored as-is (shadows the ``time`` module
            inside this method only).
        """
        if val_perf < self.incumbent_value:
            self.incumbent_value = val_perf
            self.incumbent_config = config
        record = (config, val_perf, test_perf, val_pred, test_pred, time)
        self.observations.append(record)
        self.population.append(record)

        # Eliminate the oldest observation (regularized / aging evolution)
        if len(self.population) > self.population_size:
            self.population.pop(0)

        self.ensemble = EnsembleSelection(ensemble_size=self.ens_size,
                                          task_type=self.task_type,
                                          scorer=self.scorer)
        # Only members with a validation prediction can join the ensemble. The two
        # lists stay index-aligned; ensemble.model_idx refers to these indices.
        members = [(ob[0], ob[3]) for ob in self.population if ob[3] is not None]
        self.e_config_list = [member_config for member_config, _ in members]
        self.e_valid_list = [member_pred for _, member_pred in members]
        self.cmp_config_list = []

        if not self.e_valid_list:
            return

        self.ensemble.fit(self.e_valid_list, self.val_y_labels)
        print('Idx in population: ' + str(self.ensemble.model_idx))
        self._report_ensemble_score()

        # Get configs in the intermediate ensemble
        self.cmp_config_list = [self.e_config_list[i] for i in self.ensemble.model_idx]

    def _report_ensemble_score(self):
        """Print the fitted ensemble's validation score (diagnostic output only)."""
        if self.scorer is None:
            return
        ens_val_pred = self.ensemble.predict(self.e_valid_list)
        if np.ndim(ens_val_pred) > 1:
            # Class-probability matrix -> predicted labels; 1-D predictions are used as-is.
            ens_val_pred = np.argmax(ens_val_pred, axis=-1)
        # scikit-learn metric functions take (y_true, y_pred, **kwargs); forward the
        # scorer's own kwargs (e.g. `average=` for F1) when it has any.
        score_kwargs = getattr(self.scorer, '_kwargs', None) or {}
        print(self.scorer._score_func(self.val_y_labels, ens_val_pred, **score_kwargs))
