# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
from __future__ import annotations
from typing import Callable, Sequence, Union, Optional, Any, Dict
import brainstate
import brainunit as u
import jax
import numpy as np
from pinnx.fnspace import FunctionSpace
from pinnx.geometry import DictPointGeometry
from pinnx.icbc.base import ICBC
from pinnx.utils import run_if_all_none
from pinnx.utils.sampler import BatchSampler
from .pde import TimePDE
__all__ = [
'PDEOperator',
'PDEOperatorCartesianProd',
]
Inputs = Any
Outputs = Any
Auxiliary = Any
Residual = Any
[docs]
class PDEOperator(TimePDE):
"""
PDE solution operator.
Args:
function_space: Instance of ``pinnx.fnspace.FunctionSpace``.
evaluation_points: A NumPy array of shape (n_points, dim). Discretize the input
function sampled from `function_space` using point-wise evaluations at a set
of points as the input of the branch net.
num_function (int): The number of functions for training.
function_variables: ``None`` or a list of integers. The functions in the
`function_space` may not have the same domain as the PDE. For example, the
PDE is defined on a spatio-temporal domain (`x`, `t`), but the function is
IC, which is only a function of `x`. In this case, we need to specify the
variables of the function by `function_variables=[0]`, where `0` indicates
the first variable `x`. If ``None``, then we assume the domains of the
function and the PDE are the same.
num_fn_test: The number of functions for testing PDE loss. The testing functions
for BCs/ICs are the same functions used for training. If ``None``, then the
training functions will be used for testing.
"""
def __init__(
self,
geometry: DictPointGeometry,
pde: Callable[[Inputs, Outputs, Auxiliary], Residual],
constraints: Union[ICBC, Sequence[ICBC]],
function_space: FunctionSpace,
evaluation_points,
num_function: int,
function_variables: Optional[Sequence[int]] = None,
num_test: int = None,
approximator: Optional[brainstate.nn.Module] = None,
solution: Callable[[brainstate.typing.PyTree], brainstate.typing.PyTree] = None,
num_domain: int = 0, # for space PDE
num_boundary: int = 0, # for space PDE
num_initial: int = 0, # for time PDE
num_fn_test: int = None,
train_distribution: str = "Hammersley",
anchors: Optional[brainstate.typing.ArrayLike] = None,
exclusions=None,
loss_fn: str | Callable = 'MSE',
loss_weights: Sequence[float] = None,
):
assert isinstance(function_space, FunctionSpace), (
f"Expected `function_space` to be an instance of `FunctionSpace`, "
f"but got {type(function_space)}."
)
self.fn_space = function_space
self.eval_pts = evaluation_points
self.func_vars = (
function_variables
if function_variables is not None
else list(range(geometry.dim))
)
self.num_fn = num_function
self.num_fn_test = num_fn_test
self.fn_train_bc = None
self.fn_train_x = None
self.fn_train_y = None
self.fn_train_aux_vars = None
self.fn_test_x = None
self.fn_test_y = None
self.fn_test_aux_vars = None
super().__init__(
geometry=geometry,
pde=pde,
constraints=constraints,
approximator=approximator,
loss_fn=loss_fn,
loss_weights=loss_weights,
num_initial=num_initial,
num_domain=num_domain,
num_boundary=num_boundary,
train_distribution=train_distribution,
anchors=anchors,
exclusions=exclusions,
solution=solution,
num_test=num_test,
)
def call_pde_errors(self, inputs, outputs, **kwargs):
num_bcs = self.num_bcs
self.num_bcs = self.num_fn_bcs
losses = super().call_pde_errors(inputs, outputs, **kwargs)
self.num_bcs = num_bcs
return losses
def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs):
num_bcs = self.num_bcs
self.num_bcs = self.num_fn_bcs
losses = super().call_bc_errors(loss_fns, loss_weights, inputs, outputs, **kwargs)
self.num_bcs = num_bcs
return losses
[docs]
@run_if_all_none("fn_train_x", "fn_train_y", "fn_train_aux_vars")
def train_next_batch(self, batch_size=None):
super().train_next_batch(batch_size)
self.num_fn_bcs = [n * self.num_fn for n in self.num_bcs]
func_feats = self.fn_space.random(self.num_fn)
func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
v, x, vx = self.bc_inputs(func_feats, func_vals)
if self._pde is not None:
v_pde, x_pde, vx_pde = self.gen_inputs(
func_feats,
func_vals,
self.geometry.dict_to_arr(self.train_x_all)
)
v = np.vstack((v, v_pde))
x = np.vstack((x, x_pde))
vx = np.vstack((vx, vx_pde))
self.fn_train_x = (v, x)
self.fn_train_aux_vars = {'aux': vx}
return self.fn_train_x, self.fn_train_x, self.fn_train_aux_vars
[docs]
@run_if_all_none("fn_test_x", "fn_test_y", "fn_test_aux_vars")
def test(self):
super().test()
if self.num_fn_test is None:
self.fn_test_x = self.fn_train_x
self.fn_test_aux_vars = self.fn_train_aux_vars
else:
func_feats = self.fn_space.random(self.num_fn_test)
func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
# TODO: Use different BC data from self.fn_train_x
v, x, vx = self.train_bc
if self._pde is not None:
test_x = self.geometry.dict_to_arr(self.test_x)
v_pde, x_pde, vx_pde = self.gen_inputs(
func_feats,
func_vals,
test_x[sum(self.num_bcs):]
)
v = np.vstack((v, v_pde))
x = np.vstack((x, x_pde))
vx = np.vstack((vx, vx_pde))
self.fn_test_x = (v, x)
self.fn_test_aux_vars = {'aux': vx}
return self.fn_test_x, self.fn_test_y, self.fn_test_aux_vars
def gen_inputs(self, func_feats, func_vals, points):
# Format:
# v1, x_1
# ...
# v1, x_N1
# v2, x_1
# ...
# v2, x_N1
v = np.repeat(func_vals, len(points), axis=0)
x = np.tile(points, (len(func_feats), 1))
vx = self.fn_space.eval_batch(func_feats, points[:, self.func_vars]).reshape(-1, 1)
return v, x, vx
def bc_inputs(self, func_feats, func_vals):
if len(self.constraints) == 0:
self.train_bc = (
np.empty((0, len(self.eval_pts)), dtype=brainstate.environ.dftype()),
np.empty((0, self.geometry.dim), dtype=brainstate.environ.dftype()),
np.empty((0, 1), dtype=brainstate.environ.dftype()),
)
return self.train_bc
v, x, vx = [], [], []
bcs_start = np.cumsum([0] + self.num_bcs)
train_x_bc = self.geometry.dict_to_arr(self.train_x_bc)
for i, _ in enumerate(self.num_bcs):
beg, end = bcs_start[i], bcs_start[i + 1]
vi, xi, vxi = self.gen_inputs(func_feats, func_vals, train_x_bc[beg:end])
v.append(vi)
x.append(xi)
vx.append(vxi)
self.train_bc = (np.vstack(v), np.vstack(x), np.vstack(vx))
return self.train_bc
[docs]
def resample_train_points(self, pde_points=True, bc_points=True):
"""
Resample the training points for the operator.
"""
super().resample_train_points(pde_points=pde_points, bc_points=bc_points)
self.fn_train_x, self.fn_train_x, self.fn_train_aux_vars = None, None, None
self.train_next_batch()
[docs]
class PDEOperatorCartesianProd(TimePDE):
"""
PDE solution operator with problem in the format of Cartesian product.
Args:
pde: Instance of ``pinnx.problem.PDE`` or ``pinnx.problem.TimePDE``.
function_space: Instance of ``pinnx.problem.FunctionSpace``.
evaluation_points: A NumPy array of shape (n_points, dim). Discretize the input
function sampled from `function_space` using pointwise evaluations at a set
of points as the input of the branch net.
num_function (int): The number of functions for training.
function_variables: ``None`` or a list of integers. The functions in the
`function_space` may not have the same domain as the PDE. For example, the
PDE is defined on a spatio-temporal domain (`x`, `t`), but the function is
IC, which is only a function of `x`. In this case, we need to specify the
variables of the function by `function_variables=[0]`, where `0` indicates
the first variable `x`. If ``None``, then we assume the domains of the
function and the PDE are the same.
num_test: The number of functions for testing PDE loss. The testing functions
for BCs/ICs are the same functions used for training. If ``None``, then the
training functions will be used for testing.
batch_size: Integer or ``None``.
Attributes:
train_x: A tuple of two Numpy arrays (v, x) fed into PIDeepONet for training. v
is the function input to the branch net and has the shape (`N1`, `dim1`); x
is the point input to the trunk net and has the shape (`N2`, `dim2`).
"""
def __init__(
self,
geometry: DictPointGeometry,
pde: Callable[[Inputs, Outputs, Auxiliary], Residual],
constraints: Union[ICBC, Sequence[ICBC]],
function_space: FunctionSpace,
evaluation_points,
num_function: int,
function_variables: Optional[Sequence[int]] = None,
num_test: int = None,
approximator: Optional[brainstate.nn.Module] = None,
solution: Callable[[brainstate.typing.PyTree], brainstate.typing.PyTree] = None,
num_domain: int = 0, # for space PDE
num_boundary: int = 0, # for space PDE
num_initial: int = 0, # for time PDE
num_fn_test: int = None, # for function space
train_distribution: str = "Hammersley",
anchors: Optional[brainstate.typing.ArrayLike] = None,
exclusions=None,
loss_fn: str | Callable = 'MSE',
loss_weights: Sequence[float] = None,
batch_size: int = None,
):
assert isinstance(function_space, FunctionSpace), (
f"Expected `function_space` to be an instance of `FunctionSpace`, "
f"but got {type(function_space)}."
)
self.fn_space = function_space
self.eval_pts = evaluation_points
self.func_vars = (
function_variables
if function_variables is not None
else list(range(geometry.dim))
)
self.num_fn = num_function
self.num_fn_test = num_fn_test
self.train_sampler = BatchSampler(self.num_fn, shuffle=True)
self.batch_size = batch_size
self.fn_train_bc = None
self.fn_train_x = None
self.fn_train_y = None
self.fn_train_aux_vars = None
self.fn_test_x = None
self.fn_test_y = None
self.fn_test_aux_vars = None
super().__init__(
geometry=geometry,
pde=pde,
constraints=constraints,
approximator=approximator,
loss_fn=loss_fn,
loss_weights=loss_weights,
num_initial=num_initial,
num_domain=num_domain,
num_boundary=num_boundary,
train_distribution=train_distribution,
anchors=anchors,
exclusions=exclusions,
solution=solution,
num_test=num_test,
)
def call_pde_errors(self, inputs, outputs, **kwargs):
bcs_start = np.cumsum([0] + self.num_bcs)
# PDE inputs and outputs, computing PDE losses
pde_inputs = (inputs[0], jax.tree.map(lambda x: x[bcs_start[-1]:], inputs[1]))
pde_outputs = jax.tree.map(lambda x: x[:, bcs_start[-1]:], outputs)
pde_kwargs = jax.tree.map(lambda x: x[:, bcs_start[-1]:], kwargs)
# error
pde_errors = self.pde(pde_inputs, pde_outputs, **pde_kwargs)
return pde_errors
def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs):
bcs_start = np.cumsum([0] + self.num_bcs)
losses = []
for i, bc in enumerate(self.constraints):
# ICBC inputs and outputs, computing ICBC losses
beg, end = bcs_start[i], bcs_start[i + 1]
icbc_inputs = (inputs[0], jax.tree.map(lambda x: x[beg:end], inputs[1]))
icbc_outputs = jax.tree.map(lambda x: x[:, beg:end], outputs)
icbc_kwargs = jax.tree.map(lambda x: x[:, beg:end], kwargs)
# error
error: Dict = bc.error(icbc_inputs, icbc_outputs, **icbc_kwargs)
# loss and weights
f_loss = loss_fns[i]
if loss_weights is not None:
w = loss_weights[i]
bc_loss = jax.tree.map(lambda err: f_loss(u.math.zeros_like(err), err) * w, error)
else:
bc_loss = jax.tree.map(lambda err: f_loss(u.math.zeros_like(err), err), error)
# append to losses
losses.append({f'ibc{i}': bc_loss})
return losses
# def _losses(self, inputs, outputs, num_fn):
# bcs_start = np.cumsum([0] + self.num_bcs)
#
# losses = []
# for i in range(num_fn):
# out = outputs[i]
# # Single output
# if u.math.ndim(out) == 1:
# out = out[:, None]
# f = []
# if self.pde.pde is not None:
# f = self.pde.pde(partial(model.fn_outputs, True), inputs[1])
# if not isinstance(f, (list, tuple)):
# f = [f]
# error_f = [fi[bcs_start[-1]:] for fi in f]
# losses_i = [loss_fn(u.math.zeros_like(error), error) for error in error_f]
#
# for j, bc in enumerate(self.constraints):
# beg, end = bcs_start[j], bcs_start[j + 1]
# # The same BC points are used for training and testing.
# error = bc.error(
# self.fn_train_x[1],
# inputs[1],
# out,
# beg,
# end,
# aux_var=model.net.auxiliary_vars[i][:, None],
# )
# losses_i.append(loss_fn(u.math.zeros_like(error), error))
#
# losses.append(losses_i)
#
# losses = zip(*losses)
# # Use stack instead of as_tensor to keep the gradients.
# losses = [u.math.mean(u.math.stack(loss, 0)) for loss in losses]
# return losses
#
# def losses_train(self, inputs, outputs, targets, **kwargs):
# num_fn = self.num_fn if self.batch_size is None else self.batch_size
# return self._losses(outputs, inputs, num_fn)
#
# def losses_test(self, inputs, outputs, targets, **kwargs):
# return self._losses(outputs, inputs, len(self.test_x[0]))
[docs]
def train_next_batch(self, batch_size=None):
super().train_next_batch(batch_size)
if self.fn_train_x is None:
train_x = self.geometry.dict_to_arr(self.train_x)
func_feats = self.fn_space.random(self.num_fn)
func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
vx = self.fn_space.eval_batch(func_feats, train_x[:, self.func_vars])
self.fn_train_x = (func_vals, train_x)
self.fn_train_aux_vars = {'aux': vx}
if self.batch_size is None:
return self.fn_train_x, self.train_y, self.fn_train_aux_vars
indices = self.train_sampler.get_next(self.batch_size)
train_x = (self.fn_train_x[0][indices], self.fn_train_x[1])
return train_x, self.train_y, {'aux': self.fn_train_aux_vars['aux'][indices]}
[docs]
@run_if_all_none("fn_test_x", "test_y", "fn_test_aux_vars")
def test(self):
super().test()
if self.num_fn_test is None:
self.fn_test_x = self.fn_train_x
self.fn_test_aux_vars = self.fn_train_aux_vars
else:
test_x = self.geometry.dict_to_arr(self.test_x)
func_feats = self.fn_space.random(self.num_fn_test)
func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
vx = self.fn_space.eval_batch(func_feats, test_x[:, self.func_vars])
self.fn_test_x = (func_vals, test_x)
self.fn_test_aux_vars = {'aux': vx}
return self.fn_test_x, self.test_y, {'aux': self.fn_test_aux_vars}