# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
from functools import wraps
from typing import Dict, Callable, Sequence, Union, Optional, Tuple, Any, Iterator
import brainstate
import brainunit as u
# Alias used to annotate the ``transform`` argument of ``GradientTransform``:
# a factory that is called as ``transform(fun, has_aux=..., **params)`` and
# returns the callable that is later applied to the inputs.
TransformFn = Callable
# Names exported by ``from <module> import *``.
__all__ = [
    'jacobian', 'hessian', 'gradient',
]
class GradientTransform(brainstate.util.PrettyRepr):
    """
    Apply a differentiation ``transform`` to ``target`` while carrying state writes.

    On the first evaluation of ``target`` the states it writes are recorded
    with ``brainstate.StateTraceStack``. Their values are passed through the
    transform as auxiliary data and assigned back to the states once the
    transform returns.

    Args:
        target: The function to differentiate.
        transform: Factory called as ``transform(fun, has_aux=True, **transform_params)``;
            it must return the callable that is applied to the inputs.
        return_value: If True, the value of ``target`` is returned next to the
            gradients.
        has_aux: If True, ``target`` returns ``(value, aux)`` and only ``value``
            is differentiated.
        transform_params: Extra keyword arguments forwarded to ``transform``.
    """

    def __init__(
        self,
        target: Callable,
        transform: TransformFn,
        return_value: bool = False,
        has_aux: bool = False,
        transform_params: Optional[Dict[str, Any]] = None,
    ):
        self._return_value = return_value
        self._has_aux = has_aux

        # target
        self.target = target

        # States written by ``target``; filled in lazily on the first call.
        self._states_to_be_written: Optional[Tuple[brainstate.State, ...]] = None

        # transform
        # ``has_aux=True`` is passed in both cases on purpose: the wrapped
        # function always returns ``(value, (outputs, state_values))``.
        params = {} if transform_params is None else transform_params
        wrapped = self._fun_with_aux if self._has_aux else self._fun_without_aux
        self._transform = transform(wrapped, has_aux=True, **params)

    def __pretty_repr__(self) -> Iterator[Union[brainstate.util.PrettyType, brainstate.util.PrettyAttr]]:
        """Yield the type name followed by the attributes shown in the repr."""
        yield brainstate.util.PrettyType(self.__class__.__name__)
        yield brainstate.util.PrettyAttr("target", self.target)
        yield brainstate.util.PrettyAttr("return_value", self._return_value)
        yield brainstate.util.PrettyAttr("has_aux", self._has_aux)
        yield brainstate.util.PrettyAttr("transform", self._transform)

    def _call_target(self, *args, **kwargs):
        """Call ``target``; on the first call also record the states it writes."""
        if self._states_to_be_written is None:
            with brainstate.StateTraceStack() as stack:
                output = self.target(*args, **kwargs)
                # Stored as a tuple so the collection matches its annotation
                # and cannot be mutated by accident afterwards.
                self._states_to_be_written = tuple(stack.get_write_states())
        else:
            output = self.target(*args, **kwargs)
        return output

    def _fun_with_aux(self, *args, **kwargs):
        """
        Adapt a ``target`` that returns ``(value, aux)``.

        Users should return the auxiliary data like::

            # 1. example of return one data
            return scalar_loss, data
            # 2. example of return multiple data
            return scalar_loss, (data1, data2, ...)

        Returns:
            ``(outs[0], (outs, state_values))`` where ``outs[0]`` is the value
            to differentiate and ``outs`` is the complete output of ``target``.
        """
        outs = self._call_target(*args, **kwargs)
        assert self._states_to_be_written is not None, "The states to be written should be collected."
        return outs[0], (outs, [v.value for v in self._states_to_be_written])

    def _fun_without_aux(self, *args, **kwargs):
        """
        Adapt a ``target`` that returns only the value to differentiate.

        Returns:
            ``(out, (out, state_values))``.
        """
        out = self._call_target(*args, **kwargs)
        assert self._states_to_be_written is not None, "The states to be written should be collected."
        return out, (out, [v.value for v in self._states_to_be_written])

    def _return(self, rets):
        """
        Write the new state values back and shape the user-facing result.

        Args:
            rets: ``(grads, (outputs, new_state_values))`` as produced by the
                transform.

        Returns:
            ``grads``, optionally followed by the function value (when
            ``return_value`` is set) and the auxiliary data (when ``has_aux``
            is set).
        """
        grads, (outputs, new_dyn_vals) = rets
        for state, val in zip(self._states_to_be_written, new_dyn_vals):
            state.value = val

        if self._return_value:
            if self._has_aux:
                return grads, outputs[0], outputs[1]
            return grads, outputs
        if self._has_aux:
            return grads, outputs[1]
        return grads

    def __call__(self, *args, **kwargs):
        """Run the transform on the given inputs and post-process its result."""
        rets = self._transform(*args, **kwargs)
        return self._return(rets)
def _raw_jacrev(
fun: Callable,
has_aux: bool = False,
y: str | Sequence[str] | None = None,
x: str | Sequence[str] | None = None,
) -> Callable:
# process only for y
if isinstance(y, str):
y = [y]
if y is not None:
fun = _format_y(fun, y, has_aux=has_aux)
# process only for x
if isinstance(x, str):
x = [x]
def transform(inputs):
if x is not None:
fun2, inputs = _format_x(fun, x, inputs)
return u.autograd.jacrev(fun2, has_aux=has_aux)(inputs)
else:
return u.autograd.jacrev(fun, has_aux=has_aux)(inputs)
return transform
def _raw_jacfwd(
fun: Callable,
has_aux: bool = False,
y: str | Sequence[str] | None = None,
x: str | Sequence[str] | None = None,
) -> Callable:
# process only for y
if isinstance(y, str):
y = [y]
if y is not None:
fun = _format_y(fun, y, has_aux=has_aux)
# process only for x
if isinstance(x, str):
x = [x]
def transform(inputs):
if x is not None:
fun2, inputs = _format_x(fun, x, inputs)
return u.autograd.jacfwd(fun2, has_aux=has_aux)(inputs)
else:
return u.autograd.jacfwd(fun, has_aux=has_aux)(inputs)
return transform
def _raw_hessian(
    fun: Callable,
    has_aux: bool = False,
    y: str | Sequence[str] | None = None,
    xi: str | Sequence[str] | None = None,
    xj: str | Sequence[str] | None = None,
) -> Callable:
    r"""
    Physical unit-aware version of `jax.hessian <https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html>`_,
    computing the Hessian of ``fun`` as forward-over-reverse differentiation.

    H[y][xi][xj] = d^2y / dxi dxj

    Args:
        fun: Function whose Hessian is to be computed. It is called with a
            single ``inputs`` argument.
        has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
            first element is considered the output of the mathematical function to be
            differentiated and the second element is auxiliary data. Default False.
        y: Output key(s) to differentiate; ``None`` keeps the output untouched.
        xi: Input key(s) of the inner (reverse-mode) derivative; ``None``
            differentiates with respect to the whole ``inputs``.
        xj: Input key(s) of the outer (forward-mode) derivative; ``None``
            differentiates with respect to the whole ``inputs``.

    Returns:
        A function ``transform(inputs)`` that evaluates the Hessian of ``fun``.
    """
    # Inner derivative: reverse-mode Jacobian of the selected outputs w.r.t. xi.
    first_order = _raw_jacrev(fun, has_aux=has_aux, y=y, x=xi)

    # A single name is shorthand for a one-element list.
    xj_keys = [xj] if isinstance(xj, str) else xj

    def transform(inputs):
        # Outer derivative: forward-mode Jacobian of ``first_order`` w.r.t. xj.
        if xj_keys is None:
            return u.autograd.jacfwd(first_order, has_aux=has_aux)(inputs)
        restricted, selected = _format_x(first_order, xj_keys, inputs)
        return u.autograd.jacfwd(restricted, has_aux=has_aux)(selected)

    return transform
def _format_x(fn, x_keys, xs):
assert isinstance(xs, dict), 'xs must be a dictionary.'
assert isinstance(x_keys, (tuple, list)), 'x must be a tuple or list.'
assert all(isinstance(key, str) for key in x_keys), 'x_keys must be a tuple or list of strings.'
others = {key: xs[key] for key in xs if key not in x_keys}
xs = {key: xs[key] for key in x_keys}
@wraps(fn)
def fn_new(inputs):
return fn({**inputs, **others})
return fn_new, xs
def _format_y(fn, y, has_aux: bool):
assert isinstance(y, (tuple, list)), 'y must be a tuple or list.'
assert all(isinstance(key, str) for key in y), 'y must be a tuple or list of strings.'
@wraps(fn)
def fn_new(inputs):
if has_aux:
outs, _aux = fn(inputs)
return {key: outs[key] for key in y}, _aux
else:
outs = fn(inputs)
return {key: outs[key] for key in y}
return fn_new
def jacobian(
    fn: Callable,
    xs: Dict,
    y: str | Sequence[str] | None = None,
    x: str | Sequence[str] | None = None,
    mode: str = 'backward',
    vmap: bool = True,
):
    """
    Compute `Jacobian matrix <https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant>`_
    J as J[i, j] = dy_i / dx_j, where i = 0, ..., dim_y - 1 and j = 0, ..., dim_x - 1.

    Args:
        fn: Function to differentiate. When ``y`` is given its output is
            indexed by the names in ``y``.
        xs: Inputs of the function. When ``x`` is given this must be a
            dictionary containing the names in ``x``.
        y (str, sequence of str, or None): Output name(s) to differentiate,
            i.e. the rows J[`i`, :]. If ``None``, the whole output is
            differentiated.
        x (str, sequence of str, or None): Input name(s) to differentiate with
            respect to, i.e. the columns J[:, `j`]. If ``None``, the
            derivative is taken with respect to all of ``xs``.
        mode: The mode of the gradient computation. Choose between
            'backward' and 'forward'.
        vmap: If True (default), the transform is wrapped in
            ``brainstate.transform.vmap`` before being applied to ``xs``;
            otherwise it is applied to ``xs`` directly.

    Returns:
        The Jacobian of the selected outputs with respect to the selected
        inputs.

    Raises:
        AssertionError: If ``mode`` is not a string or not one of the two
            supported values.
        ValueError: If ``mode`` is invalid and assertions are disabled.
    """
    # NOTE: ``xs`` is not type-checked here; ``_format_x`` asserts that it is
    # a dictionary whenever ``x`` is given.
    assert isinstance(mode, str), 'mode must be a string.'
    assert mode in ['backward', 'forward'], 'mode must be either backward or forward.'

    # A single name is shorthand for a one-element list.
    if isinstance(x, str):
        x = [x]
    if isinstance(y, str):
        y = [y]

    # Pick the raw transform for the requested differentiation mode.
    if mode == 'backward':
        raw_transform = _raw_jacrev
    elif mode == 'forward':
        raw_transform = _raw_jacfwd
    else:
        # Only reachable when assertions are stripped (``python -O``).
        raise ValueError('Invalid mode. Choose between backward and forward.')

    transform = GradientTransform(fn, raw_transform, transform_params={'y': y, 'x': x})
    if vmap:
        return brainstate.transform.vmap(transform)(xs)
    return transform(xs)
def hessian(
    fn: Callable,
    xs: Dict,
    y: str | Sequence[str] | None = None,
    xi: str | Sequence[str] | None = None,
    xj: str | Sequence[str] | None = None,
    vmap: bool = True,
):
    """
    Compute `Hessian matrix <https://en.wikipedia.org/wiki/Hessian_matrix>`_ H as
    H[i, j] = d^2y / dx_i dx_j, where i,j = 0, ..., dim_x - 1.

    Args:
        fn: Function to differentiate. When ``y`` is given its output is
            indexed by the names in ``y``.
        xs: Inputs of the function. When ``xi`` or ``xj`` is given this must
            be a dictionary containing those names.
        y (str, sequence of str, or None): The output variable(s). If
            ``None``, the whole output is differentiated.
        xi (str, sequence of str, or None): Input name(s) of the first
            (inner) derivative, i.e. the rows H[`i`, :]. If ``None``, all of
            ``xs`` is used.
        xj (str, sequence of str, or None): Input name(s) of the second
            (outer) derivative, i.e. the columns H[:, `j`]. If ``None``, all
            of ``xs`` is used.
        vmap: If True (default), the transform is wrapped in
            ``brainstate.transform.vmap`` before being applied to ``xs``;
            otherwise it is applied to ``xs`` directly.

    Returns:
        The Hessian of the selected outputs with respect to the selected
        inputs, H[`i`, `j`].
    """
    # NOTE: ``xs`` is not type-checked here; ``_format_x`` asserts that it is
    # a dictionary whenever ``xi`` or ``xj`` is given.
    transform = GradientTransform(
        fn,
        _raw_hessian,
        transform_params={'y': y, 'xi': xi, 'xj': xj},
    )
    if vmap:
        return brainstate.transform.vmap(transform)(xs)
    return transform(xs)
def gradient(
    fn: Callable,
    xs: Dict,
    y: str | Sequence[str] | None = None,
    *xi: str | Sequence[str] | None,
    order: int = 1,
):
    """
    Compute the gradient dy/dx of a function y = f(x) with respect to x.

    If order is 1, it computes the first derivative dy/dx. For higher orders
    the first derivative is taken in reverse mode and every further derivative
    in forward mode. The result is evaluated through
    ``brainstate.transform.vmap`` applied to ``xs``.

    Args:
        fn: Function to differentiate. When ``y`` is given its output is
            indexed by the names in ``y``.
        xs: Inputs of the function. When any ``xi`` is given this must be a
            dictionary containing those names.
        y (str, sequence of str, or None): The variable(s) to differentiate.
            If ``None``, the whole output is differentiated.
        *xi (str, sequence of str, or None): The variable(s) to differentiate
            with respect to, one entry per derivative order. If omitted, every
            derivative is taken with respect to all of ``xs``.
        order: The order of the gradient. Default is 1.

    Returns:
        dy/dx (or the corresponding higher-order derivative).

    Raises:
        AssertionError: If ``order`` is not a positive integer, or if ``xi``
            is given with a length different from ``order``.
    """
    assert isinstance(order, int), 'order must be an integer.'
    assert order > 0, 'order must be positive.'

    # A single name is shorthand for a one-element list. The selection of the
    # outputs itself is performed once, inside ``_raw_jacrev`` below.
    if isinstance(y, str):
        y = [y]

    # Normalise ``xi`` to one entry (list of names or None) per derivative.
    if xi:
        assert len(xi) == order, 'The number of xi must be equal to order.'
        xi = [[key] if isinstance(key, str) else key for key in xi]
    else:
        xi = [None] * order

    # Stack the derivatives: reverse mode first, forward mode afterwards.
    for i, x in enumerate(xi):
        if i == 0:
            fn = _raw_jacrev(fn, y=y, x=x)
        else:
            # The outputs were already selected by the first derivative.
            fn = _raw_jacfwd(fn, y=None, x=x)
    return brainstate.transform.vmap(fn)(xs)