import datetime
import json
import logging
import os
import time
import warnings
from collections import OrderedDict

import numpy as np
import pandas as pd
from numpy import corrcoef
from pandas import DataFrame, Series
from sklearn.metrics import accuracy_score, balanced_accuracy_score, matthews_corrcoef, f1_score, classification_report  # , roc_curve, auc
from sklearn.metrics import mean_absolute_error, explained_variance_score, r2_score, mean_squared_error, median_absolute_error  # , max_error

from ..constants import BINARY, MULTICLASS, REGRESSION, SOFTCLASS
from ..trainer.abstract_trainer import AbstractTrainer
from ..tuning.ensemble_selection import EnsembleSelection
from ..utils import get_pred_from_proba
from ...data.label_cleaner import LabelCleaner
from ...utils.loaders import load_pkl, load_pd
from ...utils.savers import save_pkl, save_pd

# Module-level logger used by the learner methods below.
logger = logging.getLogger(__name__)


# TODO: - Semi-supervised learning
# TODO: - Minimize memory usage of DataFrames (convert int64 -> uint8 when possible etc.)
# The learner encompasses the full problem: loading initial data, feature generation, model training, and model prediction.
# TODO: Loading learner from S3 on Windows may cause issues due to os.path.sep (paths are built by string concatenation).
class AbstractLearner:
    """Base learner wrapping label cleaning, feature generation and a trainer.

    The learner exposes fitting, prediction, scoring, evaluation and persistence on top of a
    trainer object. Subclasses must implement ``fit``.
    """
    save_file_name = 'learner.pkl'

    def __init__(self, path_context: str, label: str, id_columns: list, feature_generator, label_count_threshold=10,
                 problem_type=None, objective_func=None, stopping_metric=None, is_trainer_present=False):
        """
        Args:
            path_context (str): Root output directory. Sub-paths are built by string concatenation
                (see ``create_contexts``), so it is expected to end with a path separator.
            label (str): Name of the label column.
            id_columns (list): Columns identifying a row; used for submissions and the prediction cache.
            feature_generator: Feature generator applied to raw inputs; stored as the single entry of
                ``feature_generators``.
            label_count_threshold (int): Stored as ``self.threshold``; not read by this class itself.
            problem_type: One of the problem-type constants, or None if it is to be inferred.
            objective_func: Metric callable used for evaluation and scoring.
            stopping_metric: Stored on the learner; not read by this class itself.
            is_trainer_present (bool): If True the trainer object is kept on the learner (``self.trainer``)
                instead of being loaded from ``self.trainer_path``.
        """
        self.path_context, self.model_context, self.latest_model_checkpoint, self.eval_result_path, self.pred_cache_path, self.save_path = self.create_contexts(path_context)
        self.label = label
        self.submission_columns = id_columns
        self.threshold = label_count_threshold
        self.problem_type = problem_type
        self.trainer_problem_type = None
        self.objective_func = objective_func
        self.stopping_metric = stopping_metric
        self.is_trainer_present = is_trainer_present
        self.cleaner = None
        self.label_cleaner: LabelCleaner = None
        self.feature_generator = feature_generator
        self.feature_generators = [self.feature_generator]

        self.trainer: AbstractTrainer = None
        self.trainer_type = None
        self.trainer_path = None
        self.reset_paths = False

        self.time_fit_total = None
        self.time_fit_preprocessing = None
        self.time_fit_training = None
        self.time_limit = None

    @property
    def class_labels(self):
        """Ordered class labels from the label cleaner for multiclass problems, otherwise None."""
        if self.problem_type == MULTICLASS:
            return self.label_cleaner.ordered_class_labels
        else:
            return None

    def set_contexts(self, path_context):
        """Recompute and store every path attribute derived from ``path_context``."""
        self.path_context, self.model_context, self.latest_model_checkpoint, self.eval_result_path, self.pred_cache_path, self.save_path = self.create_contexts(path_context)

    def create_contexts(self, path_context):
        """Build the derived paths for ``path_context``.

        Returns:
            tuple: (path_context, model_context, latest_model_checkpoint, eval_result_path,
            predictions_path, save_path), all built by plain string concatenation.
        """
        model_context = path_context + 'models' + os.path.sep
        latest_model_checkpoint = model_context + 'model_checkpoint_latest.pointer'
        eval_result_path = model_context + 'eval_result.pkl'
        predictions_path = path_context + 'predictions.csv'
        save_path = path_context + self.save_file_name
        return path_context, model_context, latest_model_checkpoint, eval_result_path, predictions_path, save_path

    def fit(self, X: DataFrame, X_test: DataFrame = None, scheduler_options=None, hyperparameter_tune=True,
            feature_prune=False, holdout_frac=0.1, hyperparameters=None, verbosity=2):
        """Fit the learner. Must be implemented by subclasses.

        ``hyperparameters=None`` stands in for the former ``{}`` default (avoids a shared mutable default);
        implementations should treat None as an empty dict.
        """
        raise NotImplementedError

    # TODO: Add pred_proba_cache functionality as in predict()
    def predict_proba(self, X_test: DataFrame, model=None, as_pandas=False, inverse_transform=True, sample=None):
        """Predict class probabilities (or regression values) for X_test.

        Args:
            X_test (DataFrame): Raw features; transformed with ``transform_features`` before prediction.
            model: Optional model to predict with; passed through to ``trainer.predict_proba``.
            as_pandas (bool): Return a DataFrame (multiclass, columns = ``class_labels``) or a Series named
                after the label instead of the raw trainer output.
            inverse_transform (bool): Map probabilities back through ``label_cleaner.inverse_transform_proba``.
            sample (int): If given, only the first ``sample`` rows of X_test are used.
        """
        ##########
        # Enable below for local testing # TODO: do we want to keep sample option?
        if sample is not None:
            X_test = X_test.head(sample)
        ##########
        trainer = self.load_trainer()

        X_test = self.transform_features(X_test)
        y_pred_proba = trainer.predict_proba(X_test, model=model)
        if inverse_transform:
            y_pred_proba = self.label_cleaner.inverse_transform_proba(y_pred_proba)
        if as_pandas:
            if self.problem_type == MULTICLASS:
                y_pred_proba = pd.DataFrame(data=y_pred_proba, columns=self.class_labels)
            else:
                y_pred_proba = pd.Series(data=y_pred_proba, name=self.label)
        return y_pred_proba

    # TODO: Add decorators for cache functionality, return core code to previous state
    # use_pred_cache to check for a cached prediction of rows, can dramatically speedup repeated runs
    # add_to_pred_cache will update pred_cache with new predictions
    def predict(self, X_test: DataFrame, model=None, as_pandas=False, sample=None, use_pred_cache=False, add_to_pred_cache=False):
        """Predict labels for X_test, optionally reusing and/or extending an on-disk prediction cache.

        The cache lives at ``self.pred_cache_path`` and maps ``submission_columns`` to a label prediction.

        Args:
            X_test (DataFrame): Raw features. Its index is expected to be unique (predictions are realigned by index).
            model: Optional model to predict with.
            as_pandas (bool): Return a pd.Series named after the label instead of a numpy array.
            sample (int): If given, only the first ``sample`` rows of X_test are predicted.
            use_pred_cache (bool): Reuse cached predictions for rows whose ids are in the cache.
            add_to_pred_cache (bool): Add newly computed predictions to the cache file.

        Returns:
            Predicted labels in X_test row order (np.ndarray, or pd.Series if ``as_pandas``).
        """
        if sample is not None:
            # Sample here rather than only inside predict_proba: predictions are re-indexed with
            # X_test_cache_miss.index below, which must have the same length as the predictions.
            X_test = X_test.head(sample)

        pred_cache = None
        if use_pred_cache or add_to_pred_cache:
            try:
                pred_cache = load_pd.load(path=self.pred_cache_path, dtype=X_test[self.submission_columns].dtypes.to_dict())
            except Exception:
                # Best-effort: a missing or unreadable cache just means nothing is cached yet.
                logger.debug(f'Unable to load prediction cache from {self.pred_cache_path}, ignoring it')
                pred_cache = None

        if use_pred_cache and (pred_cache is not None):
            X_id = X_test[self.submission_columns]
            X_in_cache_with_pred = pd.merge(left=X_id.reset_index(), right=pred_cache, on=self.submission_columns).set_index('index')  # Will break if 'index' == self.label or 'index' in self.submission_columns
            X_test_cache_miss = X_test[~X_test.index.isin(X_in_cache_with_pred.index)]
            logger.log(20, f'Using cached predictions for {len(X_in_cache_with_pred)} out of {len(X_test)} rows, '
                           f'which have already been predicted previously. To make new predictions, set use_pred_cache=False')
        else:
            X_in_cache_with_pred = pd.DataFrame(data=None, columns=self.submission_columns + [self.label])
            X_test_cache_miss = X_test

        if len(X_test_cache_miss) > 0:
            y_pred_proba = self.predict_proba(X_test=X_test_cache_miss, model=model, inverse_transform=False)
            problem_type = self.trainer_problem_type if self.trainer_problem_type is not None else self.problem_type
            y_pred = get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=problem_type)
            y_pred = self.label_cleaner.inverse_transform(pd.Series(y_pred))
            y_pred.index = X_test_cache_miss.index
        else:
            logger.debug('All X_test rows found in cache, no need to load model')
            # Realign to X_test's row order, exactly as the partially-cached path does below;
            # the merge above is not guaranteed to keep X_test's order.
            y_pred = X_in_cache_with_pred[self.label].reindex(X_test.index).values
            if as_pandas:
                y_pred = pd.Series(data=y_pred, name=self.label)
            return y_pred

        if add_to_pred_cache:
            X_id_with_y_pred = X_test_cache_miss[self.submission_columns].copy()
            X_id_with_y_pred[self.label] = y_pred
            if pred_cache is None:
                pred_cache = X_id_with_y_pred.drop_duplicates(subset=self.submission_columns).reset_index(drop=True)
            else:
                # New predictions come first so drop_duplicates keeps them over stale cached ones.
                pred_cache = pd.concat([X_id_with_y_pred, pred_cache]).drop_duplicates(subset=self.submission_columns).reset_index(drop=True)
            save_pd.save(path=self.pred_cache_path, df=pred_cache)

        if len(X_in_cache_with_pred) > 0:
            y_pred = pd.concat([y_pred, X_in_cache_with_pred[self.label]]).reindex(X_test.index)

        y_pred = y_pred.values
        if as_pandas:
            y_pred = pd.Series(data=y_pred, name=self.label)
        return y_pred

    # TODO: Experimental, not integrated with core code, highly subject to change
    # TODO: Add X, y parameters -> Requires proper preprocessing of train data
    # X should be X_train from original fit call, if None then load saved X_train in trainer (if save_data=True)
    # y should be y_train from original fit call, if None then load saved y_train in trainer (if save_data=True)
    # Compresses bagged ensembles to a single model fit on 100% of the data.
    # Results in worse model quality (-), but much faster inference times (+++), reduced memory usage (+++), and reduced space usage (+++).
    def compress(self):
        """Call ``trainer.compress``; currently always with X=None, y=None (see TODOs above)."""
        X = None
        y = None
        if X is not None:
            if y is None:
                X, y = self.extract_label(X)
            X = self.transform_features(X)
            y = self.label_cleaner.transform(y)
        else:
            y = None
        trainer = self.load_trainer()
        trainer.compress(X=X, y=y)

    # TODO: Experimental, not integrated with core code, highly subject to change
    # TODO: Add X, y parameters -> Requires proper preprocessing of train data
    # X should be X_train from original fit call, if None then load saved X_train in trainer (if save_data=True)
    # y should be y_train from original fit call, if None then load saved y_train in trainer (if save_data=True)
    # Distills the full ensemble into a single model trained on 100% of the data.
    # Results in significantly worse model quality (--), but extremely faster inference times (++++), minimal memory usage (++++), and minimal space usage (++++).
    def distill(self):
        """Call ``trainer.distill``; currently always with X=None, y=None (see TODOs above)."""
        X = None
        y = None
        if X is not None:
            if y is None:
                X, y = self.extract_label(X)
            X = self.transform_features(X)
            if self.problem_type != MULTICLASS and self.problem_type != SOFTCLASS:
                y = self.label_cleaner.transform(y)
        else:
            y = None
        trainer = self.load_trainer()
        trainer.distill(X=X, y=y)

    def augment_distill(self, X=None, y=None, num_augmented_samples=50000):
        """Preprocess optional training data and delegate to ``trainer.augment_distill``.

        Args:
            X (DataFrame): Optional raw training data. If ``y`` is None the label column is extracted from X.
            y: Optional labels. They are passed through ``label_cleaner.transform`` except for
                MULTICLASS / SOFTCLASS problems, where they are forwarded as given.
            num_augmented_samples (int): Forwarded to the trainer.
        """
        if X is not None:
            if y is None:
                X, y = self.extract_label(X)
            X = self.transform_features(X)
            if self.problem_type != MULTICLASS and self.problem_type != SOFTCLASS:
                y = self.label_cleaner.transform(y)
        else:
            y = None
        trainer = self.load_trainer()
        trainer.augment_distill(X=X, y=y, num_augmented_samples=num_augmented_samples)

    def fit_transform_features(self, X, y=None):
        """Fit each feature generator in order and return the transformed X."""
        for feature_generator in self.feature_generators:
            X = feature_generator.fit_transform(X, y)
        return X

    def transform_features(self, X):
        """Apply each (already fitted) feature generator in order and return the transformed X."""
        for feature_generator in self.feature_generators:
            X = feature_generator.transform(X)
        return X

    def _fill_unknown_multiclass_labels(self, trainer, y):
        """Replace missing (NaN) cleaned labels with -1 for multiclass scoring.

        NOTE(review): NaN here presumably marks classes label_cleaner did not know; confirm in LabelCleaner.

        Raises:
            ValueError: if such labels exist and the metric scores probabilities instead of predictions.
        """
        y = y.fillna(-1)
        if (not trainer.objective_func_expects_y_pred) and (-1 in y.unique()):
            # Log loss
            raise ValueError(f'Multiclass scoring with eval_metric={self.objective_func.name} does not support unknown classes.')
        return y

    def _to_learner_proba(self, trainer, pred_proba):
        """Inverse-transform probabilities when the trainer is BINARY but the learner is MULTICLASS."""
        if (trainer.problem_type == BINARY) and (self.problem_type == MULTICLASS):
            return self.label_cleaner.inverse_transform_proba(pred_proba)
        return pred_proba

    def _score_pred_proba(self, trainer, y, pred_proba):
        """Score probabilities with ``objective_func``, converting to predictions first if the metric expects them."""
        if trainer.objective_func_expects_y_pred:
            pred = get_pred_from_proba(y_pred_proba=pred_proba, problem_type=self.problem_type)
            return self.objective_func(y, pred)
        return self.objective_func(y, pred_proba)

    def score(self, X: DataFrame, y=None, model=None):
        """Score the trainer on X.

        Args:
            X (DataFrame): Raw features; if ``y`` is None it must contain the label column.
            y: Optional labels.
            model: Optional model to score; passed through to ``trainer.score``.

        Raises:
            ValueError: multiclass labels unknown to the label cleaner with a probability-based metric.
        """
        if y is None:
            X, y = self.extract_label(X)
        X = self.transform_features(X)
        y = self.label_cleaner.transform(y)
        trainer = self.load_trainer()
        if self.problem_type == MULTICLASS:
            y = self._fill_unknown_multiclass_labels(trainer=trainer, y=y)
        return trainer.score(X=X, y=y, model=model)

    # Scores both learner and all individual models, along with computing the optimal ensemble score + weights (oracle)
    def score_debug(self, X: DataFrame, y=None, silent=False):
        """Score every model per stack level plus an oracle ensemble per level, merged with the leaderboard.

        Args:
            X (DataFrame): Raw features; if ``y`` is None it must contain the label column.
            y: Optional labels.
            silent (bool): Passed to ``leaderboard`` (suppresses printing).

        Returns:
            DataFrame with columns 'model', 'score_test', 'score_val', 'fit_time', 'pred_time_test_full',
            'pred_time_test', 'pred_time_val', 'stack_level' followed by any other leaderboard columns.
        """
        if y is None:
            X, y = self.extract_label(X)
        X = self.transform_features(X)
        y = self.label_cleaner.transform(y)
        trainer = self.load_trainer()
        if self.problem_type == MULTICLASS:
            y = self._fill_unknown_multiclass_labels(trainer=trainer, y=y)
        # TODO: Move below into trainer, should not live in learner

        max_level_to_check = trainer.get_max_level_all()
        scores = {}
        pred_times = {}
        pred_times_full = {}
        pred_time_offset = 0
        pred_probas = None
        stack_names = list(trainer.models_level.keys())
        stack_names_not_core = [name for name in stack_names if name != 'core']

        for level in range(max_level_to_check + 1):
            # Stacker inputs for this level use the previous level's core predictions.
            X_stack = trainer.get_inputs_to_stacker(X, level_start=0, level_end=level, y_pred_probas=pred_probas)

            # Non-core stacks are scored but do not feed later levels or the oracle ensemble.
            for stack_name in stack_names_not_core:
                model_names_aux = trainer.models_level[stack_name][level]
                if len(model_names_aux) > 0:
                    pred_probas_aux, pred_probas_time_aux = self.get_pred_probas_models_and_time(X=X_stack, trainer=trainer, model_names=model_names_aux)
                    for model_name, pred_proba, pred_time in zip(model_names_aux, pred_probas_aux, pred_probas_time_aux):
                        pred_times[model_name] = pred_time
                        pred_times_full[model_name] = pred_time + pred_time_offset
                        pred_proba = self._to_learner_proba(trainer=trainer, pred_proba=pred_proba)
                        scores[model_name] = self._score_pred_proba(trainer=trainer, y=y, pred_proba=pred_proba)

            model_names_core = trainer.models_level['core'][level]
            if len(model_names_core) > 0:
                pred_probas, pred_probas_time = self.get_pred_probas_models_and_time(X=X_stack, trainer=trainer, model_names=model_names_core)
                for model_name, pred_proba, pred_time in zip(model_names_core, pred_probas, pred_probas_time):
                    pred_times[model_name] = pred_time
                    pred_times_full[model_name] = pred_time + pred_time_offset
                    pred_proba = self._to_learner_proba(trainer=trainer, pred_proba=pred_proba)
                    scores[model_name] = self._score_pred_proba(trainer=trainer, y=y, pred_proba=pred_proba)
                pred_time_offset += sum(pred_probas_time)

                # Oracle: ensemble weights fit directly on the evaluation labels (an optimistic upper bound).
                oracle_name = f'oracle_ensemble_l{level + 1}'
                ensemble_selection = EnsembleSelection(ensemble_size=100, problem_type=trainer.problem_type, metric=self.objective_func)
                ensemble_selection.fit(predictions=pred_probas, labels=y, identifiers=None)
                oracle_weights = ensemble_selection.weights_
                oracle_pred_time_start = time.time()
                oracle_pred_proba_norm = [pred * weight for pred, weight in zip(pred_probas, oracle_weights)]
                oracle_pred_proba_ensemble = np.sum(oracle_pred_proba_norm, axis=0)
                oracle_pred_proba_ensemble = self._to_learner_proba(trainer=trainer, pred_proba=oracle_pred_proba_ensemble)
                oracle_pred_time = time.time() - oracle_pred_time_start
                pred_times[oracle_name] = oracle_pred_time
                pred_times_full[oracle_name] = oracle_pred_time + pred_time_offset
                scores[oracle_name] = self._score_pred_proba(trainer=trainer, y=y, pred_proba=oracle_pred_proba_ensemble)

        logger.debug('Model scores:')
        logger.debug(str(scores))
        df = pd.DataFrame(
            data={
                'model': list(scores.keys()),
                'score_test': list(scores.values()),
                'pred_time_test': [pred_times[model] for model in scores.keys()],
                'pred_time_test_full': [pred_times_full[model] for model in scores.keys()],
            }
        )

        df = df.sort_values(by='score_test', ascending=False).reset_index(drop=True)

        leaderboard_df = self.leaderboard(silent=silent)

        df_merged = pd.merge(df, leaderboard_df, on='model')
        df_columns_lst = df_merged.columns.tolist()
        explicit_order = [
            'model',
            'score_test',
            'score_val',
            'fit_time',
            'pred_time_test_full',
            'pred_time_test',
            'pred_time_val',
            'stack_level',
        ]
        df_columns_other = [column for column in df_columns_lst if column not in explicit_order]
        df_columns_new = explicit_order + df_columns_other
        df_merged = df_merged[df_columns_new]

        # TODO: Fix pred_time_test_full value for weighted_ensembles / models who only have X base_models instead of the full level.
        #  Currently it is over-estimating prediction time
        #  Fix by implementing DAG representation
        return df_merged

    def get_pred_probas_models_and_time(self, X, trainer, model_names):
        """Predict with each named model and time it.

        Returns:
            (list of prediction outputs concatenated across models, list of per-model wall-clock seconds).
        """
        pred_probas_lst = []
        pred_probas_time_lst = []
        for model_name in model_names:
            model = trainer.load_model(model_name)
            time_start = time.time()
            pred_probas = trainer.pred_proba_predictions(models=[model], X_test=X)
            if (self.problem_type == MULTICLASS) and (not trainer.objective_func_expects_y_pred):
                # Handles case where we need to add empty columns to represent classes that were not used for training
                pred_probas = [self.label_cleaner.inverse_transform_proba(pred_proba) for pred_proba in pred_probas]
            time_diff = time.time() - time_start
            pred_probas_lst += pred_probas
            pred_probas_time_lst.append(time_diff)
        return pred_probas_lst, pred_probas_time_lst

    @staticmethod
    def _take_positions(values, positions):
        """Select entries of a Series/DataFrame or array-like by integer position."""
        if isinstance(values, (Series, DataFrame)):
            return values.iloc[positions]
        return np.asarray(values)[positions]

    def evaluate(self, y_true, y_pred, silent=False, auxiliary_metrics=False, detailed_report=True, high_always_good=False):
        """ Evaluate predictions.
            Args:
                y_true: True labels (Series or array-like). Missing entries (None / NaN for regression,
                    None / '' otherwise) are dropped together with the matching y_pred entries, by position.
                y_pred: Predicted labels, positionally aligned with y_true.
                silent (bool): Should we print which metric is being used as well as performance.
                auxiliary_metrics (bool): Should we compute other (problem_type specific) metrics in addition to the default metric?
                detailed_report (bool): Should we compute more-detailed versions of the auxiliary_metrics? (requires auxiliary_metrics=True).
                high_always_good (bool): If True, this means higher values of returned metric are ALWAYS superior (so metrics like MSE should be returned negated)

            Returns single performance-value if auxiliary_metrics=False.
            Otherwise returns dict where keys = metrics, values = performance along each metric
            (the detailed per-class report, if computed, is stored under 'classification_report').
        """

        # Remove missing labels and produce warning if any are found:
        if self.problem_type == REGRESSION:
            missing_indicators = [(y is None or np.isnan(y)) for y in y_true]
        else:
            missing_indicators = [(y is None or y == '') for y in y_true]
        missing_inds = [i for i, j in enumerate(missing_indicators) if j]
        if len(missing_inds) > 0:
            num_total = len(missing_indicators)  # count before filtering, for the warning below
            nonmissing_inds = [i for i, j in enumerate(missing_indicators) if not j]
            # Positional selection: plain [] on a Series with a non-default index would select by label.
            y_true = self._take_positions(y_true, nonmissing_inds)
            y_pred = self._take_positions(y_pred, nonmissing_inds)
            warnings.warn(f"There are {len(missing_inds)} (out of {num_total}) evaluation datapoints for which the label is missing. "
                          f"AutoGluon removed these points from the evaluation, which thus may not be entirely representative. "
                          f"You should carefully study why there are missing labels in your evaluation data.")

        perf = self.objective_func(y_true, y_pred)
        metric = self.objective_func.name
        if not high_always_good:
            sign = self.objective_func._sign
            perf = perf * sign  # flip negative once again back to positive (so higher is no longer necessarily better)
        if not silent:
            logger.log(20, f"Evaluation: {metric} on test data: {perf}")
        if not auxiliary_metrics:
            return perf

        # Otherwise compute auxiliary metrics:
        perf_dict = OrderedDict({metric: perf})
        if self.problem_type == REGRESSION:  # Additional metrics: R^2, Mean-Absolute-Error, Pearson correlation
            pearson_corr = lambda x, y: corrcoef(x, y)[0][1]
            pearson_corr.__name__ = 'pearson_correlation'
            regression_metrics = [
                mean_absolute_error, explained_variance_score, r2_score, pearson_corr, mean_squared_error, median_absolute_error,
                # max_error
            ]
            for reg_metric in regression_metrics:
                metric_name = reg_metric.__name__
                if metric_name not in perf_dict:
                    perf_dict[metric_name] = reg_metric(y_true, y_pred)
        else:  # Compute classification metrics
            classif_metrics = [accuracy_score, balanced_accuracy_score, matthews_corrcoef]
            if self.problem_type == BINARY:  # binary-specific metrics
                # def auc_score(y_true, y_pred): # TODO: this requires y_pred to be probability-scores
                #     fpr, tpr, _ = roc_curve(y_true, y_pred, pos_label)
                #   return auc(fpr, tpr)
                f1micro_score = lambda y_true, y_pred: f1_score(y_true, y_pred, average='micro')
                f1micro_score.__name__ = f1_score.__name__
                classif_metrics += [f1micro_score]  # TODO: add auc?
            elif self.problem_type == MULTICLASS:  # multiclass metrics
                classif_metrics += []  # TODO: No multi-class specific metrics for now. Include, top-1, top-5, top-10 accuracy here.
            for cl_metric in classif_metrics:
                metric_name = cl_metric.__name__
                if metric_name not in perf_dict:
                    perf_dict[metric_name] = cl_metric(y_true, y_pred)

        if not silent:
            logger.log(20, "Evaluations on test data:")
            logger.log(20, json.dumps(perf_dict, indent=4))
        if detailed_report and (self.problem_type != REGRESSION):
            # One final set of metrics to report.
            # Explicit key: reading __name__ from a lambda here would have produced '<lambda>'.
            metric_name = 'classification_report'
            if metric_name not in perf_dict:
                perf_dict[metric_name] = classification_report(y_true, y_pred, output_dict=True)
                if not silent:
                    logger.log(20, "Detailed (per-class) classification report:")
                    logger.log(20, json.dumps(perf_dict[metric_name], indent=4))
        return perf_dict

    def extract_label(self, X):
        """Split X into (features without the label column, copy of the label column).

        Raises:
            ValueError: if X has no column named ``self.label``.
        """
        if self.label not in list(X.columns):
            raise ValueError(f"Provided DataFrame does not contain label column: {self.label}")
        y = X[self.label].copy()
        X = X.drop(self.label, axis=1)
        return X, y

    def submit_from_preds(self, X_test: DataFrame, y_pred_proba, save=True, save_proba=False):
        """Build a submission frame (id columns + decoded label) from prediction probabilities.

        Args:
            X_test (DataFrame): Frame providing ``submission_columns``.
            y_pred_proba: Trainer-space probabilities (not inverse-transformed).
            save (bool): Write the submission CSV under ``<model_context>submissions/`` with a UTC timestamp.
            save_proba (bool): Also write the raw probabilities (only when ``save``).
        """
        submission = X_test[self.submission_columns].copy()
        y_pred = get_pred_from_proba(y_pred_proba=y_pred_proba, problem_type=self.problem_type)

        submission[self.label] = y_pred
        submission[self.label] = self.label_cleaner.inverse_transform(submission[self.label])

        if save:
            timestamp_str_now = datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            submissions_dir = self.model_context + 'submissions' + os.path.sep
            path_submission = submissions_dir + 'submission_' + timestamp_str_now + '.csv'
            path_submission_proba = submissions_dir + 'submission_proba_' + timestamp_str_now + '.csv'
            save_pd.save(path=path_submission, df=submission)
            if save_proba:
                submission_proba = pd.DataFrame(y_pred_proba)  # TODO: Fix for multiclass
                save_pd.save(path=path_submission_proba, df=submission_proba)

        return submission

    def predict_and_submit(self, X_test: DataFrame, save=True, save_proba=False):
        """Predict probabilities for X_test and turn them into a submission (see ``submit_from_preds``)."""
        y_pred_proba = self.predict_proba(X_test=X_test, inverse_transform=False)
        return self.submit_from_preds(X_test=X_test, y_pred_proba=y_pred_proba, save=save, save_proba=save_proba)

    def leaderboard(self, X=None, y=None, silent=False):
        """Return the trainer leaderboard, or ``score_debug`` results when X is given; print it unless silent."""
        if X is not None:
            leaderboard = self.score_debug(X=X, y=y, silent=True)
        else:
            trainer = self.load_trainer()
            leaderboard = trainer.leaderboard()
        if not silent:
            with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 1000):
                print(leaderboard)
        return leaderboard

    def info(self):
        """Return ``trainer.info()`` updated with learner-level paths and fit timings."""
        trainer = self.load_trainer()
        trainer_info = trainer.info()
        learner_info = {
            'path_context': self.path_context,
            'time_fit_preprocessing': self.time_fit_preprocessing,
            'time_fit_training': self.time_fit_training,
            'time_fit_total': self.time_fit_total,
            'time_limit': self.time_limit,
        }

        trainer_info.update(learner_info)
        return trainer_info

    @staticmethod
    def get_problem_type(y: Series):
        """ Identifies which type of prediction problem we are interested in (if user has not specified).
            Ie. binary classification, multi-class classification, or regression.

            Raises:
                ValueError: if y is empty or all of its values are missing.
                NotImplementedError: if the label dtype is not float, int, object, or has exactly two values.
        """
        if len(y) == 0:
            raise ValueError("provided labels cannot have length = 0")
        y = y.dropna()  # Remove missing values from y (there should not be any though as they were removed in Learner.general_data_processing())
        if len(y) == 0:
            # Without this the unique-ratio computations below would divide by zero.
            raise ValueError("provided labels cannot all be missing")
        unique_vals = y.unique()
        num_rows = len(y)
        logger.log(20, f'Here are the first 10 unique label values in your data:  {unique_vals[:10]}')
        unique_count = len(unique_vals)
        MULTICLASS_LIMIT = 1000  # if numeric and class count would be above this amount, assume it is regression
        if num_rows > 1000:
            REGRESS_THRESHOLD = 0.05  # if the unique-ratio is less than this, we assume multiclass classification, even when labels are integers
        else:
            REGRESS_THRESHOLD = 0.1

        if unique_count == 2:
            problem_type = BINARY
            reason = "only two unique label-values observed"
        elif np.issubdtype(unique_vals.dtype, np.floating):
            unique_ratio = unique_count / float(num_rows)
            if (unique_ratio <= REGRESS_THRESHOLD) and (unique_count <= MULTICLASS_LIMIT):
                try:
                    can_convert_to_int = np.array_equal(y, y.astype(int))
                except Exception:  # e.g. values that cannot be cast to int
                    can_convert_to_int = False
                if can_convert_to_int:
                    problem_type = MULTICLASS
                    reason = "dtype of label-column == float, but few unique label-values observed and label-values can be converted to int"
                else:
                    problem_type = REGRESSION
                    reason = "dtype of label-column == float and label-values can't be converted to int"
            else:
                problem_type = REGRESSION
                reason = "dtype of label-column == float and many unique label-values observed"
        elif unique_vals.dtype == 'object':
            problem_type = MULTICLASS
            reason = "dtype of label-column == object"
        elif np.issubdtype(unique_vals.dtype, np.integer):
            unique_ratio = unique_count / float(num_rows)
            if (unique_ratio <= REGRESS_THRESHOLD) and (unique_count <= MULTICLASS_LIMIT):
                problem_type = MULTICLASS  # TODO: Check if integers are from 0 to n-1 for n unique values, if they have a wide spread, it could still be regression
                reason = "dtype of label-column == int, but few unique label-values observed"
            else:
                problem_type = REGRESSION
                reason = "dtype of label-column == int and many unique label-values observed"
        else:
            raise NotImplementedError(f'label dtype {unique_vals.dtype} not supported!')
        logger.log(25, "AutoGluon infers your prediction problem is: %s  (because %s)", problem_type, reason)
        logger.log(25, "If this is wrong, please specify `problem_type` argument in fit() instead (You may specify problem_type as one of: ['%s', '%s', '%s'])\n", BINARY, MULTICLASS, REGRESSION)
        return problem_type

    def save(self):
        """Pickle the whole learner to ``self.save_path``."""
        save_pkl.save(path=self.save_path, object=self)

    # reset_paths=True if the learner files have changed location since fitting.
    # TODO: Potentially set reset_paths=False inside load function if it is the same path to avoid re-computing paths on all models
    @classmethod
    def load(cls, path_context, reset_paths=True):
        """Load a pickled learner from ``path_context``.

        With ``reset_paths`` all derived paths (and ``trainer_path``) are rebuilt from the new ``path_context``;
        otherwise they are rebuilt from the path stored in the pickle.
        """
        load_path = path_context + cls.save_file_name
        obj = load_pkl.load(path=load_path)
        if reset_paths:
            obj.set_contexts(path_context)
            obj.trainer_path = obj.model_context
            obj.reset_paths = reset_paths
            # TODO: Still have to change paths of models in trainer + trainer object path variables
            return obj
        else:
            obj.set_contexts(obj.path_context)
            return obj

    def save_trainer(self, trainer):
        """Persist the trainer: on the learner if ``is_trainer_present``, otherwise to its own path."""
        if self.is_trainer_present:
            self.trainer = trainer
            self.save()
        else:
            self.trainer_path = trainer.path
            trainer.save()

    def load_trainer(self) -> AbstractTrainer:
        """Return the in-memory trainer if present, else load it from ``trainer_path`` via ``trainer_type``."""
        if self.is_trainer_present:
            return self.trainer
        else:
            return self.trainer_type.load(path=self.trainer_path, reset_paths=self.reset_paths)

    # TODO: Add to predictor
    # TODO: Make this safe in large ensemble situations that would result in OOM
    # Loads all models in memory so that they don't have to loaded during predictions
    def persist_trainer(self, low_memory=False):
        """Keep the trainer on the learner and, unless ``low_memory``, load all its models into memory."""
        self.trainer = self.load_trainer()
        self.is_trainer_present = True
        if not low_memory:
            self.trainer.load_models_into_memory()
            # Warning: After calling this, it is not necessarily safe to save learner or trainer anymore
            #  If neural network is persisted and then trainer or learner is saved, there will be an exception thrown
