Source code for pinnx.utils.losses

# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================


import braintools
import brainunit as u
import jax


[docs] def mean_absolute_error(y_true, y_pred): return jax.tree.map( lambda x, y: braintools.metric.absolute_error(x, y).mean(), y_true, y_pred, is_leaf=u.math.is_quantity )
[docs] def mean_squared_error(y_true, y_pred): return jax.tree.map( lambda x, y: braintools.metric.squared_error(x, y).mean(), y_true, y_pred, is_leaf=u.math.is_quantity )
[docs] def mean_l2_relative_error(y_true, y_pred): return jax.tree.map( lambda x, y: braintools.metric.l2_norm(x, y).mean(), y_true, y_pred, is_leaf=u.math.is_quantity )
[docs] def softmax_cross_entropy(y_true, y_pred): return jax.tree.map( lambda x, y: braintools.metric.softmax_cross_entropy(x, y).mean(), y_true, y_pred, is_leaf=u.math.is_quantity )
LOSS_DICT = { # mean absolute error "mean absolute error": mean_absolute_error, "MAE": mean_absolute_error, "mae": mean_absolute_error, # mean squared error "mean squared error": mean_squared_error, "MSE": mean_squared_error, "mse": mean_squared_error, # mean l2 relative error "mean l2 relative error": mean_l2_relative_error, # softmax cross entropy "softmax cross entropy": softmax_cross_entropy, }
[docs] def get_loss(identifier): """Retrieves a loss function. Args: identifier: A loss identifier. String name of a loss function, or a loss function. Returns: A loss function. """ if isinstance(identifier, (list, tuple)): return list(map(get_loss, identifier)) if isinstance(identifier, str): return LOSS_DICT[identifier] if callable(identifier): return identifier raise ValueError("Could not interpret loss function identifier:", identifier)