#!/usr/bin/env python
# -*- coding: utf-8 -*-

import jax.numpy as np

import staix
from staix import Dense, ScaleDenseNoBias, elementwise, Flatten
from jax.nn import celu, tanh, softplus

def tanh1(x):
    """Return ``tanh(x)`` divided by ``sqrt(0.394294)``.

    NOTE(review): the constant is presumably the second moment of tanh under
    a standard normal input, so the output has unit scale -- confirm.
    """
    scale = np.sqrt(0.394294)
    return tanh(x) / scale


def celu1(x):
    """Return ``celu(x, alpha=1)`` divided by ``sqrt(0.644945)``.

    NOTE(review): constant presumably normalizes the activation's second
    moment for standard normal inputs -- confirm.
    """
    scale = np.sqrt(0.644945)
    return celu(x, alpha=1) / scale


def celu0_1(x):
    """Return ``celu(x, alpha=0.1)`` divided by ``sqrt(0.5125331)``.

    NOTE(review): constant presumably normalizes the activation's second
    moment for standard normal inputs -- confirm.
    """
    scale = np.sqrt(0.5125331)
    return celu(x, alpha=0.1) / scale


def softplus1(x):
    """Return ``softplus(x)`` divided by ``sqrt(0.921246)``.

    NOTE(review): constant presumably normalizes the activation's second
    moment for standard normal inputs -- confirm.
    """
    scale = np.sqrt(0.921246)
    return softplus(x) / scale


def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one list.

    Only one level of nesting is removed; items are kept in their original
    order.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat


def get_network(d, hidden, input_norm=1.0, activation="softplus"):
    """Build a bias-free dense network that ends in a single output unit.

    The layer stack passed to ``staix.serial`` is:

    1. ``ScaleDenseNoBias(hidden, input_norm=input_norm)`` + activation,
    2. ``max(d - 2, 0)`` further ``ScaleDenseNoBias(hidden)`` + activation
       pairs,
    3. a final ``ScaleDenseNoBias(1)`` layer.

    Args:
        d: Depth parameter; controls how many hidden blocks follow the first
            one (``range(2, d)``), and is stored on the result under ``"d"``.
        hidden: Width given to every hidden ``ScaleDenseNoBias`` layer; stored
            on the result under ``"hidden"``.
        input_norm: Forwarded to the first layer only.
        activation: ``"softplus"`` (uses ``softplus1``) or ``"tanh"`` (uses
            ``tanh1``).

    Returns:
        The object returned by ``staix.serial`` with the ``"d"`` and
        ``"hidden"`` entries set.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    if activation == "softplus":
        ac = softplus1
    elif activation == "tanh":
        ac = tanh1
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local further down when the layers are assembled.
        raise ValueError(
            "Unsupported activation {!r}; expected 'softplus' or 'tanh'".format(
                activation
            )
        )

    # One (dense, activation) pair for each hidden block after the first.
    hidden_blocks = flatten(
        [
            [
                ScaleDenseNoBias(hidden),
                elementwise(ac),
            ]
            for _ in range(2, d)
        ]
    )
    net = staix.serial(
        ScaleDenseNoBias(hidden, input_norm=input_norm),
        elementwise(ac),
        *hidden_blocks,
        ScaleDenseNoBias(1),
    )
    # Record the architecture hyper-parameters on the network object itself.
    net["d"] = d
    net["hidden"] = hidden
    return net


def huber_loss(predictions, targets, delta=1.0):
    """Return the elementwise Huber loss between predictions and targets.

    The loss is ``0.5 * e**2`` where the absolute error ``|e|`` is at most
    ``delta`` and ``0.5 * delta**2 + delta * (|e| - delta)`` beyond that.
    No reduction is applied.
    """
    magnitude = np.abs(predictions - targets)
    # Part of the error that falls inside the quadratic region.
    clipped = np.minimum(magnitude, delta)
    # Whatever exceeds delta is penalized linearly.
    excess = magnitude - clipped
    return 0.5 * np.square(clipped) + delta * excess


def square_loss(predictions, targets, a=1.0):
    """Return the elementwise squared error scaled by ``a / 2``.

    No reduction is applied.
    """
    residual = predictions - targets
    half_weight = a / 2
    return half_weight * np.square(residual)


def softplus_loss(predictions, targets):
    """Return ``softplus(-predictions * targets)`` elementwise.

    No reduction is applied.
    """
    product = np.multiply(predictions, targets)
    return softplus(-product)
