# Rewritten from the original implementation in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
from __future__ import annotations
from typing import Callable, Sequence, Union, Optional, Dict, List
import brainstate
import brainunit as u
import jax.tree
import numpy as np
from pinnx import utils
from pinnx.geometry import GeometryXTime, DictPointGeometry
from .base import Problem
from ..icbc.base import ICBC
# Public API of this module: the time-independent and time-dependent PDE problems.
__all__ = ["PDE", "TimePDE"]
class PDE(Problem):
    """ODE or time-independent PDE solver.

    Args:
        geometry: Instance of ``DictPointGeometry``.
        pde: Callable ``pde(inputs, outputs, **kwargs)`` returning one PDE residual or a
            list/tuple of residuals. May be ``None`` when only initial/boundary
            conditions are enforced; then no domain points are fed to the network and
            no PDE loss terms are produced (unless a subclass overrides :meth:`pde`).
        constraints: A boundary condition or a list of boundary conditions. Use ``[]`` if no
            boundary condition.
        approximator: A neural network module for approximating the solution.
        solution: The reference solution, a callable mapping input points to outputs.
        loss_fn: Loss specification forwarded to ``Problem`` (default ``'MSE'``). In
            :meth:`losses`, ``self.loss_fn`` is used either as a single loss shared by all
            terms, or as a list/tuple with one entry per PDE residual followed by one
            entry per constraint.
        num_domain (int): The number of training points sampled inside the domain.
        num_boundary (int): The number of training points sampled on the boundary.
        num_test: The number of points sampled inside the domain for testing PDE loss.
            The testing points for BCs/ICs are the same set of points used for training.
            If ``None``, then the training points will be used for testing.
        train_distribution (string): The distribution to sample training points. One of
            the following: "uniform" (equispaced grid), "pseudo" (pseudorandom), "LHS"
            (Latin hypercube sampling), "Halton" (Halton sequence), "Hammersley"
            (Hammersley sequence), or "Sobol" (Sobol sequence).
        anchors: A pytree (e.g. dict) of training points, in addition to the
            ``num_domain`` and ``num_boundary`` sampled points. Leaves are cast to the
            default floating dtype.
        exclusions: Points to be excluded for training. Not supported yet: anything
            other than ``None`` raises ``NotImplementedError`` when sampling.
        loss_weights: Optional weights, one per PDE residual followed by one per
            constraint.

    Warning:
        The testing points include points inside the domain and points on the boundary,
        and they may not have the same density, and thus the entire testing points may
        not be uniformly distributed. As a result, if you have a reference solution
        (``solution``) and would like to compute a metric such as

        .. code-block:: python

            Trainer.compile(metrics=["l2 relative error"])

        then the metric may not be very accurate. To better compute a metric, you can
        sample the points manually, and then use ``Trainer.predict()`` to predict the
        solution on these points and compute the metric:

        .. code-block:: python

            x = geometry.uniform_points(num, boundary=True)
            y_true = ...
            y_pred = trainer.predict(x)
            error = pinnx.metrics.l2_relative_error(y_true, y_pred)

    Attributes:
        train_x_all: Dict of arrays of points for PDE training (domain, boundary and
            anchor points). If there is a PDE, ``train_x_all`` is used as the training
            points of the PDE.
        train_x_bc: Dict of arrays of the training points for BCs. ``train_x_bc`` is
            constructed from ``train_x_all`` at the first step of training; by default it
            won't be updated when ``train_x_all`` changes. To update ``train_x_bc``, set it
            to ``None`` and call ``bc_points``, and then update the loss function by
            ``trainer.compile()``.
        num_bcs (list): ``num_bcs[i]`` is the number of points for ``constraints[i]``.
        train_x: Dict of arrays of the points fed into the network for training.
            ``train_x`` is ordered from BC points (``train_x_bc``) to PDE points
            (``train_x_all``), and may have duplicate points. Without a PDE it contains
            only the BC points.
        train_y: ``solution(train_x)`` if a reference solution is given, else ``None``.
        test_x: Dict of arrays of the points fed into the network for testing, ordered
            from BCs to PDE. The BC points are exactly the same points in ``train_x_bc``.
        test_y: ``solution(test_x)`` if a reference solution is given, else ``None``.
    """

    def __init__(
        self,
        geometry: DictPointGeometry,
        pde: Callable,
        constraints: Union[ICBC, Sequence[ICBC]],
        approximator: Optional[brainstate.nn.Module] = None,
        solution: Callable[[brainstate.typing.PyTree], brainstate.typing.PyTree] = None,
        loss_fn: str | Callable = 'MSE',
        num_domain: int = 0,
        num_boundary: int = 0,
        num_test: int = None,
        train_distribution: str = "Hammersley",
        anchors: Optional[brainstate.typing.ArrayLike] = None,
        exclusions=None,
        loss_weights: Sequence[float] = None,
    ):
        super().__init__(
            approximator=approximator,
            loss_fn=loss_fn,
            loss_weights=loss_weights
        )
        assert isinstance(geometry, DictPointGeometry), f"Expected DictPointGeometry, got {type(geometry)}"
        # geometry describing the domain, as a DictPointGeometry
        self.geometry = geometry

        # PDE residual function (``None``: only ICs/BCs are enforced)
        self._pde = pde
        if pde is not None:
            assert callable(pde), f"Expected callable, got {type(pde)}"

        # initial and boundary conditions; each one is bound to this geometry and problem
        self.constraints = constraints if isinstance(constraints, (list, tuple)) else [constraints]
        for bc in self.constraints:
            assert isinstance(bc, ICBC), f"Expected ICBC, got {type(bc)}"
            bc.apply_geometry(self.geometry)
            bc.apply_problem(self)

        # anchors (extra PDE training points), cast to the default floating dtype
        self.anchors = None if anchors is None else self._cast_points(anchors)

        # reference solution
        if solution is not None:
            assert callable(solution), f"Expected callable, got {type(solution)}"
        self.solution = solution

        # exclusions (currently rejected by `train_points`)
        self.exclusions = exclusions

        # sampling configuration
        self.num_domain = num_domain
        self.num_boundary = num_boundary
        self.num_test = num_test
        self.train_distribution = train_distribution

        # training data
        self.train_x_all: Dict[str, brainstate.typing.ArrayLike] = None
        self.train_x_bc: Dict[str, brainstate.typing.ArrayLike] = None
        self.num_bcs: List[int] = None
        # these include both BC and PDE points
        self.train_x: Dict[str, brainstate.typing.ArrayLike] = None
        self.train_y: Dict[str, brainstate.typing.ArrayLike] = None
        self.test_x: Dict[str, brainstate.typing.ArrayLike] = None
        self.test_y: Dict[str, brainstate.typing.ArrayLike] = None

        # generate training data and testing data
        self.train_next_batch()
        self.test()

    @staticmethod
    def _cast_points(points):
        """Cast every leaf of the point pytree ``points`` to the default floating dtype."""
        return jax.tree.map(lambda x: x.astype(brainstate.environ.dftype()), points)

    @staticmethod
    def _concat_points(first, second):
        """Concatenate two point pytrees leaf-wise along the leading (point) axis."""
        return jax.tree.map(lambda x, y: u.math.concatenate((x, y), axis=0), first, second)

    def _has_pde(self) -> bool:
        """Whether a PDE residual is available: passed to ``__init__`` or ``pde`` overridden."""
        # `self.pde` is a bound method and therefore never None; check the actual source.
        return self._pde is not None or type(self).pde is not PDE.pde

    def _assemble_train_x(self):
        """Build the network training inputs: BC points first, then PDE points (if any PDE)."""
        # Computes `num_bcs`/`train_x_bc` on first use; reuses them afterwards.
        self.bc_points()
        if not self._has_pde() or self.train_x_all is None:
            # only include data in boundary or initial conditions
            return self.train_x_bc
        if not self.train_x_bc:
            # no constraints: only PDE points
            return self.train_x_all
        # include data in boundary, initial conditions, and PDE
        return self._concat_points(self.train_x_bc, self.train_x_all)

    def pde(self, *args, **kwargs):
        """
        Compute the PDE residual by delegating to the ``pde`` callable given at construction.

        Raises:
            NotImplementedError: if no PDE callable was given.
        """
        if self._pde is not None:
            return self._pde(*args, **kwargs)
        else:
            raise NotImplementedError("PDE is not defined.")

    def call_pde_errors(self, inputs, outputs, **kwargs):
        """Evaluate the PDE residual on the PDE part of the batch.

        ``inputs``, ``outputs`` and ``kwargs`` are ordered like ``train_x`` (BC points
        first), so every leaf is sliced from the end of the last constraint's points.
        """
        bcs_start = np.cumsum([0] + self.num_bcs)
        # PDE inputs and outputs, computing PDE losses
        pde_inputs = jax.tree.map(lambda x: x[bcs_start[-1]:], inputs)
        pde_outputs = jax.tree.map(lambda x: x[bcs_start[-1]:], outputs)
        pde_kwargs = jax.tree.map(lambda x: x[bcs_start[-1]:], kwargs)
        # error
        pde_errors = self.pde(pde_inputs, pde_outputs, **pde_kwargs)
        return pde_errors

    def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs):
        """Compute one loss entry per constraint.

        Args:
            loss_fns: One loss callable per constraint, called as ``f(target, error)``
                with a zero target.
            loss_weights: One weight per constraint, or ``None`` for no weighting.
            inputs, outputs, kwargs: Batched pytrees ordered like ``train_x``; the slice
                for constraint ``i`` is ``[bcs_start[i]:bcs_start[i + 1]]``.

        Returns:
            A list of dicts ``{f'ibc{i}': loss}``, where ``loss`` mirrors the structure of
            the error returned by ``constraints[i].error``.
        """
        bcs_start = np.cumsum([0] + self.num_bcs)
        losses = []
        for i, bc in enumerate(self.constraints):
            # ICBC inputs and outputs, computing ICBC losses
            beg, end = bcs_start[i], bcs_start[i + 1]
            icbc_inputs = jax.tree.map(lambda x: x[beg:end], inputs)
            icbc_outputs = jax.tree.map(lambda x: x[beg:end], outputs)
            icbc_kwargs = jax.tree.map(lambda x: x[beg:end], kwargs)
            # error
            error: Dict = bc.error(icbc_inputs, icbc_outputs, **icbc_kwargs)
            # loss and weights (the lambdas are applied immediately, so capturing the
            # loop variables is safe)
            f_loss = loss_fns[i]
            if loss_weights is not None:
                w = loss_weights[i]
                bc_loss = jax.tree.map(lambda err: f_loss(u.math.zeros_like(err), err) * w, error)
            else:
                bc_loss = jax.tree.map(lambda err: f_loss(u.math.zeros_like(err), err), error)
            losses.append({f'ibc{i}': bc_loss})
        return losses

    @utils.check_not_none('num_bcs')
    def losses(self, inputs, outputs, targets, **kwargs):
        """Compute all loss terms: PDE losses first, then one entry per constraint.

        Returns:
            A list with one (optionally weighted) loss per PDE residual, followed by the
            per-constraint dicts produced by :meth:`call_bc_errors`.

        Raises:
            ValueError: if the number of loss functions or loss weights does not match
                the number of PDE residuals plus constraints.
        """
        # PDE residuals; without a PDE there are no PDE loss terms (and `train_x`
        # holds only BC points).
        if self._has_pde():
            pde_errors = self.call_pde_errors(inputs, outputs, **kwargs)
            if not isinstance(pde_errors, (list, tuple)):
                pde_errors = [pde_errors]
        else:
            pde_errors = []
        n_terms = len(pde_errors) + len(self.constraints)

        # loss functions: either one shared loss or one per term
        if isinstance(self.loss_fn, (list, tuple)):
            loss_fn = self.loss_fn
        else:
            loss_fn = [self.loss_fn] * n_terms
        if len(loss_fn) != n_terms:
            raise ValueError(f"There are {n_terms} errors, "
                             f"but {len(loss_fn)} loss functions were given.")

        # PDE loss
        losses = [loss_fn[i](u.math.zeros_like(error), error) for i, error in enumerate(pde_errors)]
        if self.loss_weights is not None:
            if len(self.loss_weights) != n_terms:
                raise ValueError(f"Expected {n_terms} weights, got {len(self.loss_weights)}. "
                                 f"There are {len(losses)} PDE losses and {len(self.constraints)} IC+BC losses.")
            losses = [w * loss for w, loss in zip(self.loss_weights[:len(losses)], losses)]

        # loss of boundary or initial conditions
        bc_losses = self.call_bc_errors(
            loss_fn[len(pde_errors):],
            self.loss_weights[len(pde_errors):] if self.loss_weights is not None else None,
            inputs,
            outputs,
            **kwargs
        )
        losses.extend(bc_losses)
        return losses

    @utils.run_if_all_none("train_x", "train_y")
    def train_next_batch(self, batch_size=None):
        """Sample the training points and return ``(train_x, train_y)``.

        Guarded by ``run_if_all_none("train_x", "train_y")``: meant to do the work only
        while both are ``None``. ``batch_size`` is accepted but not used.
        """
        # Generate `self.train_x_all`
        self.train_points()
        # Generate `self.num_bcs` and `self.train_x_bc`
        self.bc_points()
        self.train_x = self._assemble_train_x()
        self.train_y = self.solution(self.train_x) if self.solution is not None else None
        return self.train_x, self.train_y

    @utils.run_if_all_none("test_x", "test_y")
    def test(self):
        """Build ``(test_x, test_y)``; reuses the training points when ``num_test`` is None."""
        if self.num_test is None:
            # assign the training points to the testing points
            self.test_x = self.train_x
        else:
            # Generate `self.test_x`, resampling the test points
            self.test_x = self.test_points()
        # solution on the test points
        self.test_y = self.solution(self.test_x) if self.solution is not None else None
        return self.test_x, self.test_y

    def resample_train_points(self, pde_points=True, bc_points=True):
        """Resample the training points for PDE and/or BC."""
        if pde_points:
            self.train_x_all = None
        if bc_points:
            self.train_x_bc = None
        self.train_x, self.train_y = None, None
        self.train_next_batch()

    def add_anchors(self, anchors: brainstate.typing.PyTree):
        """
        Add new points for training PDE losses.

        The new anchors are prepended to ``train_x_all`` (concatenated along the point
        axis); the BC points will not be updated.
        """
        anchors = self._cast_points(anchors)
        if self.anchors is None:
            self.anchors = anchors
        else:
            self.anchors = self._concat_points(self.anchors, anchors)
        # include anchors in the training points
        if self.train_x_all is None:
            self.train_x_all = anchors
        else:
            self.train_x_all = self._concat_points(anchors, self.train_x_all)
        self.train_x = self._assemble_train_x()
        # solution on the training points
        self.train_y = self.solution(self.train_x) if self.solution is not None else None

    def replace_with_anchors(self, anchors):
        """Replace the current PDE training points with anchors.

        The BC points will not be changed.
        """
        self.anchors = self._cast_points(anchors)
        self.train_x_all = self.anchors
        self.train_x = self._assemble_train_x()
        # solution on the training points
        self.train_y = self.solution(self.train_x) if self.solution is not None else None

    @utils.run_if_all_none("train_x_all")
    def train_points(self):
        """Sample ``train_x_all``: anchors, then domain points, then boundary points.

        Returns ``None`` when nothing is requested (no domain/boundary points, no anchors).

        Raises:
            NotImplementedError: if ``exclusions`` is set (not supported yet).
        """
        if self.exclusions is not None:
            raise NotImplementedError("`exclusions` is not supported yet.")
        uniform = self.train_distribution == "uniform"
        X = None
        # sampling points in the domain
        if self.num_domain > 0:
            if uniform:
                X = self.geometry.uniform_points(self.num_domain, boundary=False)
            else:
                X = self.geometry.random_points(self.num_domain, random=self.train_distribution)
        # sampling points on the boundary
        if self.num_boundary > 0:
            if uniform:
                tmp = self.geometry.uniform_boundary_points(self.num_boundary)
            else:
                tmp = self.geometry.random_boundary_points(self.num_boundary, random=self.train_distribution)
            X = tmp if X is None else self._concat_points(X, tmp)
        # add anchors in front of the sampled points
        if self.anchors is not None:
            X = self.anchors if X is None else self._concat_points(self.anchors, X)
        # save the training points
        self.train_x_all = X
        return X

    @utils.run_if_all_none("train_x_bc")
    def bc_points(self):
        """
        Generate boundary condition points.

        Sets ``num_bcs`` (points per constraint) and ``train_x_bc``.

        Returns:
            dict: The boundary condition points of all constraints, concatenated along
            the leading axis (an empty dict when there are no constraints).
        """
        x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.constraints]
        # number of points per constraint, read from the leading axis of its first entry
        self.num_bcs = [len(tuple(x.values())[0]) for x in x_bcs]
        if self.num_bcs:
            self.train_x_bc = jax.tree.map(lambda *x: u.math.concatenate(x, axis=0), *x_bcs)
        else:
            self.train_x_bc = dict()
        return self.train_x_bc

    def test_points(self):
        """Sample ``num_test`` uniform domain points, prefixed by the training BC points."""
        # domain points, sampled independently of `self.train_x_all`
        x = self.geometry.uniform_points(self.num_test, boundary=False)
        # reuse the same BC points as for training
        if self.num_bcs:
            x = self._concat_points(self.train_x_bc, x)
        return x
class TimePDE(PDE):
    """Time-dependent PDE solver.

    Accepts the same arguments as :class:`PDE`, plus:

    Args:
        num_initial (int): The number of training points sampled on the initial
            location.
    """

    def __init__(
        self,
        geometry: DictPointGeometry,
        pde: Callable,
        constraints: Union[ICBC, Sequence[ICBC]],
        approximator: Optional[brainstate.nn.Module] = None,
        num_domain: int = 0,
        num_boundary: int = 0,
        num_initial: int = 0,
        train_distribution: str = "Hammersley",
        anchors=None,
        exclusions=None,
        solution=None,
        num_test: int = None,
        loss_fn: str | Callable = 'MSE',
        loss_weights: Sequence[float] = None,
    ):
        # Must be set before `super().__init__`, which samples the training points
        # (and thus calls `train_points`, which reads `num_initial`).
        self.num_initial = num_initial
        super().__init__(
            geometry,
            pde,
            constraints,
            num_domain=num_domain,
            num_boundary=num_boundary,
            train_distribution=train_distribution,
            anchors=anchors,
            exclusions=exclusions,
            solution=solution,
            num_test=num_test,
            approximator=approximator,
            loss_fn=loss_fn,
            loss_weights=loss_weights,
        )

    @utils.run_if_all_none("train_x_all")
    def train_points(self):
        """Sample the points of :meth:`PDE.train_points`, then append initial points.

        ``exclusions`` are rejected by the base implementation, so no filtering of the
        initial points is needed here.
        """
        self.geometry: GeometryXTime
        X = super().train_points()
        if self.num_initial > 0:
            if self.train_distribution == "uniform":
                tmp = self.geometry.uniform_initial_points(self.num_initial)
            else:
                tmp = self.geometry.random_initial_points(self.num_initial, random=self.train_distribution)
            # the base class returns None when it sampled nothing
            if X is None:
                X = tmp
            else:
                X = jax.tree.map(lambda x, y: u.math.concatenate((x, y), axis=0), X, tmp)
        self.train_x_all = X
        return X