#!/usr/bin/env python3

"""
a single logit for positive class prediction.
s(x) = 1 / (1 + e^x)
"""

from jax import random
from jax import numpy as jnp
from jax import jit
import optax

from fairgym.utils.misc import _augment_feat_array


def _logits(params, aug_feat_array):
    return jnp.einsum("ij,j->i", aug_feat_array, params)


def _expit_b(x, b):
    """
    Compute s(x) - b, component-wise,

    where s(x) = logistic(x) = 1 / (1 + e^-x)
    avoid loss of precision issues with naive calculation

    Adapted from
    https://fa.bianp.net/blog/2019/evaluate_logistic/

    1 / (1 + e^-x) - b
    =
    [(1 - b) - b e^-x] / (1 + e^-x) # won't blow up for postive x
    =
    [(1 - b) e^x - b] / (e^x + 1) # won't blow up for negative x
    """

    exp_x = jnp.exp(x)
    exp_nx = jnp.exp(-x)

    # won't blow up for negative x
    out_nx = ((1 - b) * exp_x - b) / (1 + exp_x)
    # won't blow up for positive x
    out_x = ((1 - b) - b * exp_nx) / (1 + exp_nx)

    return out_nx * (x < 0) + out_x * (x >= 0)


def _logistic_fn(logits):
    """
    Map logits x to the logistic function value
    s(x) = 1 / (1 + e^-x)

    Delegates to _expit_b with an offset of 0.
    """
    no_offset = 0
    return _expit_b(logits, no_offset)


def _sample_loss(logits, labels, sample_weight):
    """
    Computes logistic loss for each sample

    Adapted from
    https://fa.bianp.net/blog/2019/evaluate_logistic/

    Parameters
    ----------
    logits: array-like, shape (n_samples,)
        Array of logits

    labels: array-like, shape (n_samples,)
        true labels

    Returns
    -------
    loss: float

    Let y in {0, 1} be the label and
    Let p be the predicted probability that y=1
    Let x be the logit value such that p = logistic(x)

    Sample loss is log-loss or logistic loss or cross-entropy loss.
    L = - y log(p) - (1 - y) log(1 - p)

    When p = logistic(x) = s(x) = 1 / (1 + e^-x)
    where x is a logit value,
    log(1 - p) = log( e^-x / (1 + e^-x) ) = -x + log(s(x))

    Sample loss can thus be simplified to
    L = - y log(s(x)) - (1 - y) [ -x + log(s(x)) ]
      = (1 - y) x - log(s(x))
    """

    return (1 - labels) * logits - jnp.log(_logistic_fn(logits))


def _loss(logits, labels, sample_weight):
    """
    Total loss: the sum over samples of the values returned by _sample_loss
    """
    per_sample = _sample_loss(logits, labels, sample_weight)
    # "i->" collapses the single sample axis to a scalar
    return jnp.einsum("i->", per_sample)


def _loss_grad(params, aug_feat_array, labels, sample_weight):
    """
    Gradient of the logistic loss with respect to params.

    Adapted from
    https://fa.bianp.net/blog/2019/evaluate_logistic/

    Parameters
    ----------
    params: array-like, shape (n_features,)
        Bias column, then weights

    aug_feat_array: array-like, shape (n_samples, n_features)
        Data matrix, with column of 1s at beginning

    labels: array-like, shape (n_samples,)
        true labels

    sample_weight: array-like, shape (n_samples,)
        Per-sample multiplier applied to each sample's gradient

    Returns
    -------
    grad: array-like, shape (n_features,)


    Derivation
    ----------
    Let y in {0, 1} be the label and x the logit of a sample.

    The logistic function s(x) has derivative s(x) * (1 - s(x)),
    so log(s(x)) has derivative (1 - s(x)).

    The sample loss L = (1 - y) x - log(s(x)) therefore has derivative
    with respect to the logit x
    (1 - y) - (1 - s(x)) = s(x) - y
    which is called the residual below.

    With x = (aug_feat_array . params), the chain rule gives the gradient
    of the sample loss in params as
    residual * aug_feat_array
    """

    num_samples = aug_feat_array.shape[0]

    # s(x) - y for every sample; _expit_b evaluates s(x) - b
    residual = _expit_b(_logits(params, aug_feat_array), labels)

    # weight each sample's gradient, add them up over the sample axis,
    # then divide by the number of samples
    weighted_total = jnp.einsum(
        "ij,i,i->j", aug_feat_array, residual, sample_weight
    )
    return weighted_total / num_samples


class LogisticRegression:
    """
    Logistic regression implemented with JAX

    Parameters are trained by gradient steps (optax.sgd) on the whole
    training set, for a fixed number of iterations per call to fit.

    Basic use:

    clf = LogisticRegression()
    clf = clf.fit(X_train, Y_train)
    Y_hat = clf.predict_proba(X_test)
    """

    def __init__(
        self, random_state=0, num_iter=100, learn_rate=1e-1, stopping_update_size=1e-5
    ):
        """
        Parameters
        ----------
        random_state: int
            Seed of the PRNG key used to draw initial parameters

        num_iter: int
            Number of gradient steps taken by each call to fit

        learn_rate: float
            Learning rate handed to optax.sgd

        stopping_update_size: float
            Stored on the instance; fit does not currently consult it
        """
        self.key = random.PRNGKey(random_state)
        self.num_iter = num_iter
        self.learn_rate = learn_rate
        self.stopping_update_size = stopping_update_size

        # set by fit; None means the model has not been trained yet
        self.params = None

        self.loss_grad_method = _loss_grad

    @staticmethod
    def _params_shape(num_dimensions):
        """Shape of the parameter vector: one entry per feature plus one."""
        return (num_dimensions + 1,)

    def _fitted_params(self):
        """Return self.params, raising RuntimeError if the model is unfitted."""
        if self.params is None:
            raise RuntimeError(
                "LogisticRegression has no parameters yet; call fit() first"
            )
        return self.params

    def fit(self, feat_array, labels, sample_weight=None):
        """
        train self.params based on feat_array and labels, using sample_weights
        to weight loss of individual samples.

        Parameters
        ----------
        feat_array: array with 1 or 2 axes; the first axis indexes examples

        labels: array whose first axis has one entry per example

        sample_weight: array shaped like labels, or None for all-ones weights

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If feat_array does not have 1 or 2 axes, or labels does not have
            one entry per example
        """

        # check or coerce shapes of arguments #################################
        # Explicit raises rather than asserts, so the checks survive python -O
        # feat_array
        feat_rank = len(feat_array.shape)
        if feat_rank == 1:
            num_dimensions = 1
        elif feat_rank == 2:
            num_dimensions = feat_array.shape[1]
        else:
            raise ValueError(
                f"feat_array must have 1 or 2 axes, got shape {feat_array.shape}"
            )
        num_examples = feat_array.shape[0]

        # labels
        if labels.shape[0] != num_examples:
            raise ValueError(
                f"labels has {labels.shape[0]} entries but feat_array has "
                f"{num_examples} examples"
            )

        # presumably adds the leading column of 1s that pairs with the bias
        # entry of params -- confirm against _augment_feat_array
        aug_feat_array = _augment_feat_array(feat_array)

        # params
        params_shape = self._params_shape(num_dimensions)

        # sample_weight
        if sample_weight is None:
            sample_weight = jnp.ones(labels.shape)

        if (self.params is not None) and (self.params.shape == params_shape):
            # warm start
            params = self.params
        else:
            # cold start
            params = random.normal(self.key, shape=params_shape)

        # (re)initialize optimizer
        optimizer = optax.sgd(self.learn_rate)
        opt_state = optimizer.init(params)

        @jit
        def step(params, opt_state):
            """
            One step of gradient descent:
            params[t] -> params[t+1]
            """

            grads = self.loss_grad_method(params, aug_feat_array, labels, sample_weight)

            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)

            return params, opt_state

        # NOTE: always runs num_iter steps; self.stopping_update_size is not
        # used as an early-stopping criterion here.
        for _ in range(self.num_iter):
            params, opt_state = step(params, opt_state)

        self.params = params

        return self

    def loss(self, feat_array, labels, sample_weight=None):
        """
        Loss on feat_array and labels using currently trained parameters.

        The value is whatever _loss returns for the model's logits, which is
        the sum (not the mean) of the per-sample losses. sample_weight
        defaults to all ones and is passed through to _loss.

        Raises
        ------
        RuntimeError
            If the model has not been fit
        """

        params = self._fitted_params()

        # sample_weight
        if sample_weight is None:
            sample_weight = jnp.ones(labels.shape)

        aug_feat_array = _augment_feat_array(feat_array)
        logits = _logits(params, aug_feat_array)

        return _loss(logits, labels, sample_weight)

    def predict_proba(self, feat_array):
        """
        return probability of label=1

        Raises
        ------
        RuntimeError
            If the model has not been fit
        """

        params = self._fitted_params()

        aug_feat_array = _augment_feat_array(feat_array)

        logits = _logits(params, aug_feat_array)

        # Pr(Y=1)
        return _logistic_fn(logits)


if __name__ == "__main__":
    # Demo: fit on synthetic 1-D data where Pr(y = 1 | x) = x, then report
    # the training time and the loss on the training data.
    import time

    N = 1000
    X = jnp.linspace(0, 1, N)
    # use the same (N, 1) feature matrix for both fit and loss
    features = X.reshape((N, 1))

    key = random.PRNGKey(0)
    y = random.bernoulli(key, p=X)

    clf = LogisticRegression(
        random_state=0, num_iter=1000, learn_rate=1e0, stopping_update_size=1e-5
    )

    # perf_counter is the clock intended for measuring elapsed intervals
    start = time.perf_counter()
    clf.fit(features, y)
    # JAX dispatches work asynchronously; wait for the fitted parameters to
    # be computed so the measured time covers the actual training
    clf.params.block_until_ready()
    print(time.perf_counter() - start, "seconds")
    print(clf.loss(features, y), "loss")
