"""Random utilities and space for common code."""
import dataclasses
from typing import Optional

import cvxpy as cp

from xoid import constants
from xoid.util import network_util


def compute_loss(loss_fn: str, y_true, y_pred):
    """Build the cvxpy expression for a loss averaged over ``y_pred.shape[0]``.

    Args:
        loss_fn: Name of the loss. One of ``'l1'``, ``'l2'``,
            ``'sigmoid_cross_entropy'`` or ``'softmax_cross_entropy'``.
        y_true: Targets, combined elementwise with ``y_pred``.
        y_pred: Predictions. The two cross-entropy losses treat these as
            logits (unnormalized scores), not probabilities.

    Returns:
        The summed loss divided by the size of the first axis of ``y_pred``.

    Raises:
        ValueError: If ``loss_fn`` is not one of the supported names.
    """
    if loss_fn == 'l1':
        return cp.sum(cp.abs(y_true - y_pred)) / y_pred.shape[0]

    elif loss_fn == 'l2':
        return cp.sum_squares(y_true - y_pred) / y_pred.shape[0]

    elif loss_fn == 'sigmoid_cross_entropy':
        # With logits z and targets y the per-element loss is
        # log(1 + exp(z)) - y * z; cp.logistic(z) is log(1 + exp(z)).
        loss = -cp.sum(cp.multiply(y_true, y_pred)) + cp.sum(cp.logistic(y_pred))
        return loss / y_pred.shape[0]

    elif loss_fn == 'softmax_cross_entropy':
        # log-softmax of the logits: z - logsumexp(z) along the last axis.
        # It must be built from y_pred; building it from y_true would make
        # the loss constant with respect to the predictions.
        log_softmax = y_pred - cp.log_sum_exp(y_pred, axis=-1, keepdims=True)
        return -cp.sum(cp.multiply(log_softmax, y_true)) / y_pred.shape[0]

    else:
        raise ValueError(loss_fn)


def compute_loss_regularization(regularization: Optional[str], w):
    """Return the penalty term that a loss-type regularizer adds to the objective.

    Args:
        regularization: Regularizer name, or ``None`` for no regularization.
        w: Weights to penalize.

    Returns:
        ``cp.sum(cp.abs(w))`` for ``'l1_loss'``, ``cp.sum_squares(w)`` for
        ``'l2_loss'``, and ``None`` when ``regularization`` is ``None`` or is
        any other name listed in ``constants.REGULARIZERS``.

    Raises:
        ValueError: If ``regularization`` is neither ``None`` nor a member of
            ``constants.REGULARIZERS``.
    """
    # Nothing requested: no penalty term.
    if regularization is None:
        return None

    if regularization == 'l1_loss':
        return cp.sum(cp.abs(w))

    if regularization == 'l2_loss':
        return cp.sum_squares(w)

    # Reject names that are not known regularizers at all.
    if regularization not in constants.REGULARIZERS:
        raise ValueError(regularization)

    # A known regularizer that contributes no term to the loss.
    return None


def compute_regularization_constraints(
    regularization: Optional[str],
    reg_const: Optional[float],
    w,
    v_times_vertex=None,
):
    """Return the cvxpy constraints implied by a constraint-type regularizer.

    Args:
        regularization: Regularizer name, or ``None`` for no regularization.
        reg_const: Upper bound used on the right-hand side of the constraint.
        w: Weights being constrained.
        v_times_vertex: Left operand of ``v_times_vertex @ w.T``; required
            only for ``'lipshitz_constraint'``.

    Returns:
        A single-element list holding the constraint for
        ``'l1_constraint'``, ``'l2_constraint'`` or ``'lipshitz_constraint'``;
        an empty list when ``regularization`` is ``None`` or is any other
        name listed in ``constants.REGULARIZERS``.

    Raises:
        ValueError: If ``regularization`` is neither ``None`` nor a member of
            ``constants.REGULARIZERS``, or if ``'lipshitz_constraint'`` is
            requested without ``v_times_vertex``.
    """
    if regularization == 'l1_constraint':
        return [cp.sum(cp.abs(w)) <= reg_const]

    if regularization == 'l2_constraint':
        return [cp.sum_squares(w) <= reg_const]

    if regularization == 'lipshitz_constraint':
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would defer the failure to ``None @ w.T``.
        if v_times_vertex is None:
            raise ValueError(
                "v_times_vertex is required for 'lipshitz_constraint' regularization"
            )
        dfdx = v_times_vertex @ w.T
        # Bound the largest row-wise sum of squares of dfdx by reg_const.
        row_sq_norms = cp.sum(cp.square(dfdx), axis=1)
        return [cp.max(row_sq_norms) <= reg_const]

    if regularization is not None and regularization not in constants.REGULARIZERS:
        raise ValueError(regularization)

    # None, or a known regularizer that imposes no constraint.
    return []
