import logging, time
import pickle, psutil, sys
import math
import numpy as np

from .....try_import import try_import_catboost, try_import_catboostdev
from ..abstract.abstract_model import AbstractModel
from .hyperparameters.parameters import get_param_baseline
from .catboost_utils import construct_custom_catboost_metric, SoftclassObjective, SoftclassCustomMetric
from ...constants import PROBLEM_TYPES_CLASSIFICATION, MULTICLASS, SOFTCLASS
from ......core import Int, Real
from ....utils.exceptions import NotEnoughMemoryError, TimeLimitExceeded
from ...utils import normalize_pred_probas

logger = logging.getLogger(__name__)


# TODO: Catboost crashes on multiclass problems where only two classes have significant member count.
#  Question: Do we turn these into binary classification and then convert to multiclass output in Learner? This would make the most sense.
# TODO: Consider having Catboost variant that converts all categoricals to numerical as done in RFModel, was showing improved results in some problems.
class CatboostModel(AbstractModel):
    def __init__(self, path: str, name: str, problem_type: str, objective_func, stopping_metric=None, num_classes=None, hyperparameters=None, features=None, debug=0):
        """Set up the model and pick the CatBoost estimator class to use.

        All arguments except ``num_classes`` are forwarded unchanged to the
        parent constructor. ``num_classes`` is stored on the instance before
        the parent constructor runs.

        The classifier/regressor classes come from ``catboost``; for SOFTCLASS
        problems they are replaced by the ones from ``catboost_dev``.
        """
        self.num_classes = num_classes
        super().__init__(
            path=path,
            name=name,
            problem_type=problem_type,
            objective_func=objective_func,
            stopping_metric=stopping_metric,
            hyperparameters=hyperparameters,
            features=features,
            debug=debug,
        )
        try_import_catboost()
        from catboost import CatBoostClassifier, CatBoostRegressor
        if problem_type == SOFTCLASS:
            # Soft-label training uses the estimator classes from catboost_dev instead
            try_import_catboostdev()
            from catboost_dev import CatBoostClassifier, CatBoostRegressor

        if problem_type in PROBLEM_TYPES_CLASSIFICATION:
            self.model_type = CatBoostClassifier
        else:
            self.model_type = CatBoostRegressor
        self.best_iteration = 0

        # Name under which the eval metric is looked up in CatBoost's best-score dict:
        # the string itself, or the class name of a custom metric object.
        eval_metric = self.params['eval_metric']
        self.metric_name = eval_metric if isinstance(eval_metric, str) else type(eval_metric).__name__

    def _set_default_params(self):
        """Fill in default hyperparameters for the current problem type.

        Baseline values, a fixed random seed and an eval metric built from the
        stopping metric are registered via ``_set_default_param_value``. For
        SOFTCLASS problems the loss function and eval metric are then
        overwritten with the soft-label implementations.
        """
        baseline = get_param_baseline(problem_type=self.problem_type)
        for key, value in baseline.items():
            self._set_default_param_value(key, value)
        self._set_default_param_value('random_seed', 0)  # Remove randomness for reproducibility
        stopping_eval_metric = construct_custom_catboost_metric(
            self.stopping_metric,
            True,
            not self.stopping_metric_needs_y_pred,
            self.problem_type,
        )
        self._set_default_param_value('eval_metric', stopping_eval_metric)

        if self.problem_type != SOFTCLASS:
            return

        # Soft-label problems always use the custom objective/metric pair,
        # assigned directly into self.params rather than as defaults.
        self.params['loss_function'] = SoftclassObjective.SoftLogLossObjective()
        self.params['eval_metric'] = SoftclassCustomMetric.SoftLogLossMetric()
        self._set_default_param_value('early_stopping_rounds', 50)  # Speeds up training with custom losses

    def _get_default_searchspace(self, problem_type):
        """Return the default hyperparameter search space.

        The same space is returned regardless of ``problem_type``: a
        log-scaled learning rate, an integer tree depth and an L2 leaf
        regularization coefficient.
        """
        return dict(
            learning_rate=Real(lower=5e-3, upper=0.2, default=0.1, log=True),
            depth=Int(lower=5, upper=8, default=6),
            l2_leaf_reg=Real(lower=1, upper=5, default=3),
        )

    def preprocess(self, X):
        """Run the parent preprocessing, then replace missing categorical values.

        Every column of dtype ``category`` has its missing entries filled with
        the literal category ``'__NaN__'`` (added to the column's categories
        when not already present). The frame is copied before modification, so
        the object returned by the parent preprocessing is not mutated. Frames
        without categorical columns are returned as-is.
        """
        X = super().preprocess(X)
        cat_columns = list(X.select_dtypes(include='category').columns)
        if not cat_columns:
            return X

        X = X.copy()
        for col in cat_columns:
            column = X[col]
            if '__NaN__' not in column.cat.categories:
                # The fill value must be a known category before fillna can use it
                column = column.cat.add_categories('__NaN__')
            X[col] = column.fillna('__NaN__')
        return X

    # TODO: Use Pool in preprocess, optimize bagging to do Pool.split() to avoid re-computing pool for each fold! Requires stateful + y
    #  Pool is much more memory efficient, avoids copying data twice in memory
    def fit(self, X_train, Y_train, X_test=None, Y_test=None, time_limit=None, **kwargs):
        """Train the CatBoost model, optionally within a time budget.

        Without ``time_limit`` a single model is trained with ``self.params``.
        With ``time_limit`` a short probe model is trained first; its wall-clock
        time and pickled size per iteration are used to cap the number of
        remaining iterations, and training then continues from the probe model.

        Parameters
        ----------
        X_train, Y_train
            Training features (object exposing ``.columns`` and ``len``) and labels.
            For SOFTCLASS problems ``Y_train.shape[1]`` is taken as the class count.
        X_test, Y_test
            Optional validation data. When given it is used as CatBoost's
            ``eval_set`` and enables early stopping / best-iteration selection.
        time_limit
            Optional training budget in seconds (compared against ``time.time()``).
        **kwargs
            ``verbosity`` (default 2) selects the ``verbose`` value passed to CatBoost.

        Raises
        ------
        NotEnoughMemoryError
            If the crude memory estimate exceeds the currently available memory.
        TimeLimitExceeded
            If 60% or more of ``time_limit`` was consumed before training started.
        """
        from catboost import Pool
        if self.problem_type == SOFTCLASS:
            try_import_catboostdev()
            from catboost_dev import Pool

        num_rows_train = len(X_train)
        num_cols_train = len(X_train.columns)
        if self.problem_type == MULTICLASS:
            if self.num_classes is not None:
                num_classes = self.num_classes
            else:
                num_classes = 10  # Guess if not given, can do better by looking at y_train
        elif self.problem_type == SOFTCLASS:
            num_classes = Y_train.shape[1]
            self.num_classes = num_classes
        else:
            num_classes = 1

        # TODO: Add ignore_memory_limits param to disable NotEnoughMemoryError Exceptions
        approx_mem_size_req = num_rows_train * num_cols_train * num_classes / 2  # TODO: Extremely crude approximation, can be vastly improved
        if approx_mem_size_req > 1e9:  # > 1 GB
            available_mem = psutil.virtual_memory().available
            ratio = approx_mem_size_req / available_mem
            if ratio > 1:
                logger.warning('Warning: Not enough memory to safely train CatBoost model, roughly requires: %s GB, but only %s GB is available...' % (round(approx_mem_size_req / 1e9, 3), round(available_mem / 1e9, 3)))
                raise NotEnoughMemoryError
            elif ratio > 0.2:
                logger.warning('Warning: Potentially not enough memory to safely train CatBoost model, roughly requires: %s GB, but only %s GB is available...' % (round(approx_mem_size_req / 1e9, 3), round(available_mem / 1e9, 3)))

        start_time = time.time()
        X_train = self.preprocess(X_train)
        cat_features = list(X_train.select_dtypes(include='category').columns)
        X_train = Pool(data=X_train, label=Y_train, cat_features=cat_features)

        # Best-iteration selection and validation scores only exist when an eval set is supplied.
        has_eval_set = X_test is not None
        if has_eval_set:
            X_test = self.preprocess(X_test)
            X_test = Pool(data=X_test, label=Y_test, cat_features=cat_features)
            eval_set = X_test
            # Scale patience and probe length down for datasets larger than 10000 rows
            if num_rows_train <= 10000:
                modifier = 1
            else:
                modifier = 10000/num_rows_train
            early_stopping_rounds = max(round(modifier*150), 10)
            num_sample_iter_max = max(round(modifier*100), 2)
        else:
            eval_set = None
            early_stopping_rounds = None
            num_sample_iter_max = 100

        # These keys are dropped from self.params before the CatBoost estimator is constructed
        for invalid in ('num_threads', 'num_gpus'):
            self.params.pop(invalid, None)
        logger.log(15, 'Catboost model hyperparameters:')
        logger.log(15, self.params)

        # TODO: Add more control over these params (specifically early_stopping_rounds)
        verbosity = kwargs.get('verbosity', 2)
        if verbosity <= 1:
            verbose = False
        elif verbosity == 2:
            verbose = False
        elif verbosity == 3:
            verbose = 20
        else:
            verbose = True

        init_model = None
        init_model_tree_count = None
        init_model_best_iteration = None
        init_model_best_score = None

        if time_limit:
            time_left_start = time_limit - (time.time() - start_time)
            if time_left_start <= time_limit * 0.4:  # if 60% of time was spent preprocessing, likely not enough time to train model
                raise TimeLimitExceeded
            params_init = self.params.copy()
            num_sample_iter = min(num_sample_iter_max, params_init['iterations'])
            params_init['iterations'] = num_sample_iter
            self.model = self.model_type(
                **params_init,
            )
            # use_best_model is only requested when validation data exists;
            # asking for it without an eval_set is not a usable configuration.
            self.model.fit(
                X_train,
                eval_set=eval_set,
                use_best_model=has_eval_set,
                verbose=verbose,
            )

            init_model_tree_count = self.model.tree_count_
            if has_eval_set:
                init_model_best_iteration = self.model.get_best_iteration()
                init_model_best_score = self.model.get_best_score()['validation'][self.metric_name]

            time_left_end = time_limit - (time.time() - start_time)
            # max(..., 1) avoids dividing by zero when no probe iterations were requested
            time_taken_per_iter = (time_left_start - time_left_end) / max(num_sample_iter, 1)
            if time_taken_per_iter > 0:
                estimated_iters_in_time = round(time_left_end / time_taken_per_iter)
            else:
                # Probe finished within clock resolution: time is not the limiting factor
                estimated_iters_in_time = self.params['iterations']
            init_model = self.model

            params_final = self.params.copy()

            # TODO: This only handles memory with time_limits specified, but not with time_limits=None, handle when time_limits=None
            available_mem = psutil.virtual_memory().available
            model_size_bytes = sys.getsizeof(pickle.dumps(self.model))

            # Cap total iterations so the pickled model is estimated to stay under 30% of available memory
            max_memory_proportion = 0.3
            mem_usage_per_iter = model_size_bytes / max(num_sample_iter, 1)
            max_memory_iters = math.floor(available_mem * max_memory_proportion / mem_usage_per_iter)

            params_final['iterations'] = min(self.params['iterations'] - num_sample_iter, estimated_iters_in_time)
            if params_final['iterations'] > max_memory_iters - num_sample_iter:
                if max_memory_iters - num_sample_iter <= 500:
                    logger.warning('Warning: CatBoost will be early stopped due to lack of memory, increase memory to enable full quality models, max training iterations changed to %s from %s' % (max_memory_iters - num_sample_iter, params_final['iterations']))
                params_final['iterations'] = max_memory_iters - num_sample_iter
        else:
            params_final = self.params.copy()

        # A non-positive budget means the probe model (already in self.model) is kept as the final model
        if params_final['iterations'] > 0:
            self.model = self.model_type(
                **params_final,
            )

            # TODO: Strangely, this performs different if clone init_model is sent in than if trained for same total number of iterations. May be able to optimize catboost models further with this
            self.model.fit(
                X_train,
                eval_set=eval_set,
                verbose=verbose,
                early_stopping_rounds=early_stopping_rounds,
                # use_best_model=True,
                init_model=init_model,
            )

            if init_model is not None and has_eval_set:
                final_model_best_score = self.model.get_best_score()['validation'][self.metric_name]
                # Metric direction is inferred by comparing the score with the metric's optimum:
                # a score below the optimum is treated as higher-is-better, otherwise lower-is-better.
                if self.stopping_metric._optimum > final_model_best_score:
                    continuation_improved = final_model_best_score > init_model_best_score
                else:
                    continuation_improved = final_model_best_score < init_model_best_score

                if continuation_improved:
                    best_iteration = init_model_tree_count + self.model.get_best_iteration()
                else:
                    best_iteration = init_model_best_iteration

                # Drop every tree after the best iteration
                self.model.shrink(ntree_start=0, ntree_end=best_iteration+1)

        self.best_iteration = self.model.tree_count_ - 1
        # NOTE(review): stores tree_count_ - 1 as 'iterations'; confirm whether consumers expect tree_count_ instead
        self.params_trained['iterations'] = self.model.tree_count_ - 1

    def predict_proba(self, X, preprocess=True):
        """Return predicted class probabilities for ``X``.

        For every problem type except SOFTCLASS this defers to the parent
        implementation. For SOFTCLASS the model's raw scores
        (``prediction_type='RawFormulaVal'``) are converted to probabilities
        with a row-wise softmax.

        Parameters
        ----------
        X
            Input features.
        preprocess
            When true, ``self.preprocess`` is applied to ``X`` first.

        Returns
        -------
        A 2-D array of per-class probabilities, or the second column only when
        there are exactly two classes. The result is passed through
        ``normalize_pred_probas`` when ``self.normalize_predprobs`` is set.
        """
        if self.problem_type != SOFTCLASS:
            return super().predict_proba(X, preprocess)
        # For SOFTCLASS problem, need to manually transform predictions into probabilities
        if preprocess:
            X = self.preprocess(X)

        raw_scores = np.asarray(self.model.predict(X, prediction_type='RawFormulaVal'))
        # Softmax with the row maximum subtracted first: mathematically identical,
        # but np.exp can no longer overflow to inf (which produced NaN probabilities).
        shifted_scores = raw_scores - np.max(raw_scores, axis=1, keepdims=True)
        y_pred_proba = np.exp(shifted_scores)
        y_pred_proba = y_pred_proba / np.sum(y_pred_proba, axis=1, keepdims=True)
        if y_pred_proba.shape[1] == 2:
            y_pred_proba = y_pred_proba[:, 1]
        if self.normalize_predprobs:
            y_pred_proba = normalize_pred_probas(y_pred_proba, self.problem_type)
        return y_pred_proba
