# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
from __future__ import annotations
from typing import Callable, Sequence, Union, Optional, Dict, Any
import brainstate
import brainunit as u
import jax
import numpy as np
from pinnx import utils
from pinnx.geometry import DictPointGeometry
from pinnx.icbc.base import ICBC
from .pde import PDE
# Names exported by a star-import of this module.
__all__ = [
"IDE",
]
# Type aliases used to annotate the ``ide`` callback accepted by ``IDE``:
# ``X`` annotates its first argument and ``Y`` its second; both are
# string-keyed dicts of array-likes.
X = Dict[str, brainstate.typing.ArrayLike]
Y = Dict[str, brainstate.typing.ArrayLike]
# Annotation for the third argument of the ``ide`` callback; left
# unconstrained (``Any``).
InitMat = Any
class IDE(PDE):
    """Integro-differential equation (IDE) solver.

    The current version only supports 1D problems with the integral
    ``int_0^x K(x, t) y(t) dt``. The integral is discretised with
    Gauss-Legendre quadrature: for every collocation point ``x`` the
    ``quad_deg`` Legendre nodes are mapped from ``[-1, 1]`` to ``[0, x]``
    and appended to the point set, and :meth:`get_int_matrix` builds the
    matrix holding the matching quadrature weights times the kernel values.

    Args:
        geometry: Forwarded unchanged to ``PDE.__init__``.
        ide: Residual callback, invoked as
            ``ide(inputs, outputs, int_mat, **kwargs)`` where ``int_mat`` is
            the array returned by :meth:`get_int_matrix`.
        constraints: Forwarded unchanged to ``PDE.__init__``.
        quad_deg: Number of Gauss-Legendre nodes used per collocation point
            (passed to ``numpy.polynomial.legendre.leggauss``).
        approximator: Forwarded unchanged to ``PDE.__init__``.
        kernel: ``(x, t) --> R``. Called as ``kernel(x_col, t_pts)`` where
            ``x_col`` has shape ``(quad_deg, 1)`` and is filled with one
            collocation coordinate, and ``t_pts`` is the slice of quadrature
            points belonging to it; the result is flattened to ``quad_deg``
            values. Defaults to a kernel that is identically one.
        num_domain, num_boundary, train_distribution, anchors, solution,
            num_test, loss_fn, loss_weights: Forwarded unchanged to
            ``PDE.__init__``. ``solution`` and ``num_test`` are also read by
            the methods of this class.
    """

    def __init__(
        self,
        geometry: DictPointGeometry,
        ide: Callable[[X, Y, InitMat], Any],
        constraints: Union[ICBC, Sequence[ICBC]],
        quad_deg: int,
        approximator: Optional[brainstate.nn.Module] = None,
        kernel: Optional[Callable] = None,
        num_domain: int = 0,
        num_boundary: int = 0,
        train_distribution: str = "Hammersley",
        anchors=None,
        solution=None,
        num_test: Optional[int] = None,
        loss_fn: str | Callable = 'MSE',
        loss_weights: Optional[Sequence[float]] = None,
    ):
        # Explicit ``is None`` test so that a user-supplied callable is never
        # replaced merely because it happens to be falsy.
        if kernel is None:
            kernel = lambda x, *args: np.ones((len(x), 1))
        self.kernel = kernel

        # Gauss-Legendre nodes/weights on [-1, 1], stored in the default
        # floating dtype of the brainstate environment.
        self.quad_deg = quad_deg
        self.quad_x, self.quad_w = np.polynomial.legendre.leggauss(quad_deg)
        self.quad_x = self.quad_x.astype(brainstate.environ.dftype())
        self.quad_w = self.quad_w.astype(brainstate.environ.dftype())

        # The quadrature attributes are assigned before the base initialiser
        # runs. NOTE(review): presumably the base class may already generate
        # points during ``__init__`` -- confirm against ``PDE``.
        super().__init__(
            geometry,
            ide,
            constraints,
            approximator=approximator,
            num_domain=num_domain,
            num_boundary=num_boundary,
            train_distribution=train_distribution,
            anchors=anchors,
            solution=solution,
            num_test=num_test,
            loss_fn=loss_fn,
            loss_weights=loss_weights,
        )

    def call_pde_errors(self, inputs, outputs, **kwargs):
        """Evaluate the IDE residual and drop the leading constraint rows.

        Args:
            inputs: Passed through to the ``ide`` callback.
            outputs: Passed through to the ``ide`` callback.
            **kwargs: Extra keyword arguments forwarded to the callback.

        Returns:
            The callback's result with the first ``sum(self.num_bcs)``
            entries along axis 0 removed from every leaf.
        """
        num_bc = sum(self.num_bcs)
        # The 'fit' flag of the brainstate environment selects between the
        # training and the test point layout in ``get_int_matrix``.
        fit = brainstate.environ.get('fit')
        int_mat = self.get_int_matrix(fit)
        pde_errors = self.pde(inputs, outputs, int_mat, **kwargs)
        return jax.tree.map(lambda x: x[num_bc:], pde_errors)

    @utils.run_if_all_none("train_x", "train_y")
    def train_next_batch(self, batch_size=None):
        """Build the training points and, if available, reference values.

        The stored ``train_x`` is the concatenation along axis 0 of, in this
        order: the constraint points, the collocation points
        (``train_x_all``) and the quadrature points generated from the
        collocation points.

        Args:
            batch_size: Not used by this implementation.

        Returns:
            ``(train_x, train_y)``; ``train_y`` is ``None`` when no
            ``solution`` was supplied.
        """
        # NOTE(review): the decorator presumably skips this body once
        # ``train_x``/``train_y`` are set -- confirm in ``pinnx.utils``.
        self.train_x_all = self.train_points()
        x_bc = self.bc_points()
        x_quad = self.quad_points(self.train_x_all)
        self.train_x = jax.tree.map(
            lambda x, y, z: u.math.concatenate((x, y, z), axis=0),
            x_bc,
            self.train_x_all,
            x_quad,
            is_leaf=u.math.is_quantity,
        )
        self.train_y = self.solution(self.train_x) if self.solution else None
        return self.train_x, self.train_y

    @utils.run_if_all_none("test_x", "test_y")
    def test(self):
        """Build the test points and, if available, reference values.

        When ``num_test`` is ``None`` the training collocation points
        (``train_x_all``) are reused; otherwise fresh points come from
        :meth:`test_points`. The matching quadrature points are appended
        after them along axis 0.

        Returns:
            ``(test_x, test_y)``; ``test_y`` is ``None`` when no
            ``solution`` was supplied.
        """
        if self.num_test is None:
            self.test_x = self.train_x_all
        else:
            self.test_x = self.test_points()
        x_quad = self.quad_points(self.test_x)
        self.test_x = jax.tree.map(
            lambda x, y: u.math.concatenate((x, y), axis=0),
            self.test_x,
            x_quad,
            is_leaf=u.math.is_quantity,
        )
        self.test_y = self.solution(self.test_x) if self.solution else None
        return self.test_x, self.test_y

    def test_points(self):
        """Return ``num_test`` points from ``geometry.uniform_points``."""
        return self.geometry.uniform_points(self.num_test, True)

    def quad_points(self, X):
        """Generate the quadrature points for every coordinate in ``X``.

        Each entry ``x`` of a leaf is expanded into ``(quad_x + 1) * x / 2``,
        i.e. the Legendre nodes mapped from ``[-1, 1]`` onto ``[0, x]``, and
        the results are flattened so that the ``quad_deg`` nodes of one entry
        are contiguous.

        Args:
            X: Pytree of coordinate arrays; quantities are treated as leaves.

        Returns:
            A pytree with the same structure as ``X`` holding the flattened
            quadrature points.
        """

        def scale_nodes(xs):
            return jax.vmap(lambda x: (self.quad_x + 1) * x / 2)(xs).flatten()

        return jax.tree.map(scale_nodes, X, is_leaf=u.math.is_quantity)

    def get_int_matrix(self, training):
        """Assemble the quadrature matrix for the integral term.

        Row ``num_bc + i`` carries, in the columns of the quadrature points
        that belong to collocation point ``i``, the scaled Gauss-Legendre
        weights ``quad_w * x / 2`` multiplied by the kernel values; all other
        entries are zero.

        Args:
            training: Truthy to use the training layout (``train_x``, with
                ``sum(self.num_bcs)`` leading constraint rows); falsy to use
                the test layout (``test_x``, no constraint rows).

        Returns:
            A NumPy array of shape ``(num_bc + num_f, X.size)`` where ``X``
            is the array form of the selected point set.
        """

        def get_quad_weights(x):
            # Weights rescaled for the interval [0, x].
            return self.quad_w * x / 2

        # The matrix is built eagerly with NumPy, so any JAX operation used
        # on the way has to produce concrete values even under tracing.
        with jax.ensure_compile_time_eval():
            if training:
                num_bc = sum(self.num_bcs)
                X = self.train_x
            else:
                num_bc = 0
                X = self.test_x
            X = np.asarray(self.geometry.dict_to_arr(X))

            # Number of collocation points that own quadrature nodes.
            if training or self.num_test is None:
                num_f = tuple(self.train_x_all.values())[0].shape[0]
            else:
                num_f = self.num_test

            int_mat = np.zeros(
                (num_bc + num_f, X.size),
                dtype=brainstate.environ.dftype(),
            )
            for i in range(num_f):
                x = X[i + num_bc, 0]
                # Quadrature points follow the constraint and collocation
                # rows, ``quad_deg`` consecutive rows per collocation point.
                beg = num_f + num_bc + self.quad_deg * i
                end = beg + self.quad_deg
                K = np.ravel(
                    self.kernel(np.full((self.quad_deg, 1), x), X[beg:end])
                )
                int_mat[i + num_bc, beg:end] = get_quad_weights(x) * K
            return int_mat