import numpy as np

from ...constants import MULTICLASS, SOFTCLASS


def _as_sample_major(y_hat, num_samples):
    """Return multiclass scores as a ``(num_samples, num_classes)`` array.

    A 1-D ``y_hat`` is treated as class-major flattened scores (all samples'
    scores for class 0, then class 1, ...). It is reshaped to
    ``(num_classes, num_samples)`` and transposed. The class count is implied
    by ``len(y_hat) // num_samples``, not by the labels, so it stays correct
    even when some classes are absent from the evaluation labels.

    A 2-D ``y_hat`` is assumed to already be ``(num_samples, num_classes)``
    and is returned unchanged.
    """
    y_hat = np.asarray(y_hat)
    if y_hat.ndim == 2:
        return y_hat
    return y_hat.reshape(-1, num_samples).T


def _softmax(raw_scores):
    """Row-wise softmax of a 2-D array of raw scores.

    Each row's maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents large scores from
    overflowing to ``inf``, which would turn the probabilities into ``nan``.
    """
    exp_scores = np.exp(raw_scores - raw_scores.max(axis=1, keepdims=True))
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def func_generator(metric, is_higher_better, needs_pred_proba, problem_type):
    """Build a custom LightGBM evaluation callback that scores predictions with ``metric``.

    Args:
        metric: Callable ``metric(y_true, y_pred)`` that also has a ``name``
            attribute, which is reported alongside the score.
        is_higher_better: Passed through unchanged as the third element of the
            callback's return value.
        needs_pred_proba: If True, ``metric`` receives per-class scores or
            probabilities. Otherwise it receives hard predictions: argmax for
            MULTICLASS, ``np.round`` for every other problem type.
        problem_type: Compared against ``MULTICLASS`` and ``SOFTCLASS`` to
            decide how ``y_hat`` is reshaped and transformed.

    Returns:
        ``function_template(y_hat, data)``, which returns the tuple
        ``(metric.name, score, is_higher_better)``. ``data`` must provide
        ``get_label()``. In the SOFTCLASS + ``needs_pred_proba`` case it must
        instead provide a 2-D ``softlabels`` array.

    NOTE(review): the final branch rounds predictions for every problem type
    other than MULTICLASS. That suits binary probabilities. However, if a
    regression problem reaches it with ``needs_pred_proba=False``, its
    predictions are rounded too -- confirm this is intended.
    """
    if needs_pred_proba:
        if problem_type == MULTICLASS:
            def function_template(y_hat, data):
                y_true = data.get_label()
                # Class count comes from the prediction length, not from the
                # labels, so missing classes in y_true cannot break the reshape.
                y_hat = _as_sample_major(y_hat, len(y_true))
                return metric.name, metric(y_true, y_hat), is_higher_better
        elif problem_type == SOFTCLASS:  # metric must take in soft labels array, like soft_log_loss
            def function_template(y_hat, data):
                y_true = data.softlabels
                # y_hat is treated as raw scores; convert each row to a
                # probability distribution comparable with the soft labels.
                y_hat = _softmax(_as_sample_major(y_hat, y_true.shape[0]))
                return metric.name, metric(y_true, y_hat), is_higher_better
        else:
            def function_template(y_hat, data):
                y_true = data.get_label()
                return metric.name, metric(y_true, y_hat), is_higher_better
    elif problem_type == MULTICLASS:
        def function_template(y_hat, data):
            y_true = data.get_label()
            # Predicted class = index of the highest-scoring column per sample.
            y_hat = _as_sample_major(y_hat, len(y_true)).argmax(axis=1)
            return metric.name, metric(y_true, y_hat), is_higher_better
    else:
        def function_template(y_hat, data):
            y_true = data.get_label()
            # Turn continuous scores into hard predictions by rounding.
            y_hat = np.round(y_hat)
            return metric.name, metric(y_true, y_hat), is_higher_better
    return function_template

def softclass_lgbobj(preds, train_data):
    """ Custom LightGBM loss function for soft (probabilistic, vector-valued) class-labels only,
        which have been appended to lgb.Dataset (train_data) as additional ".softlabels" attribute (2D numpy array).

    Args:
        preds: Raw (pre-softmax) scores for every sample and class. They are
            flattened class-major, i.e. the Fortran-order flattening of a
            ``(num_samples, num_classes)`` array, which is how they are
            reshaped below.
        train_data: Object exposing ``softlabels``, a
            ``(num_samples, num_classes)`` array of target class probabilities.

    Returns:
        ``(grad, hess)``: 1-D arrays in the same Fortran-order layout as
        ``preds``. Here ``grad = p - softlabels`` and
        ``hess = 2 * p * (1 - p)``, where ``p`` is the row-wise softmax of the
        raw scores.
    """
    softlabels = train_data.softlabels
    num_classes = softlabels.shape[1]
    raw_scores = np.reshape(preds, (len(softlabels), num_classes), order='F')
    # Subtract each row's max before exponentiating. Softmax is unchanged, but
    # large raw scores can no longer overflow to inf (inf / inf would be nan).
    exp_scores = np.exp(raw_scores - raw_scores.max(axis=1, keepdims=True))
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    # Softmax cross-entropy gradient w.r.t. raw scores
    # (assuming each softlabels row sums to 1).
    grad = probs - softlabels
    # NOTE(review): diagonal softmax Hessian scaled by a hard-coded 2.0;
    # confirm this constant is intended.
    hess = 2.0 * probs * (1.0 - probs)
    return grad.flatten('F'), hess.flatten('F')
