# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
import itertools
import brainstate
import jax.numpy as jnp
from .base import AbstractGeometry
from .geometry_1d import Interval
from .geometry_2d import Rectangle
from .geometry_3d import Cuboid
from .geometry_nd import Hypercube
from ..utils import isclose
class TimeDomain(Interval):
    """Interval from ``t0`` to ``t1`` used as the time axis of a problem."""

    def __init__(self, t0, t1):
        """Initialise the underlying interval and keep both end points as arrays.

        Args:
            t0: Start of the time range.
            t1: End of the time range.
        """
        super().__init__(t0, t1)
        # Both end points are stored in the environment's default float type.
        real = brainstate.environ.dftype()
        self.t0 = jnp.asarray(t0, dtype=real)
        self.t1 = jnp.asarray(t1, dtype=real)

    def on_initial(self, t):
        """Return a flattened mask of the entries of ``t`` that ``isclose`` matches to ``t0``."""
        at_start = isclose(t, self.t0)
        return at_start.flatten()
class GeometryXTime(AbstractGeometry):
    """Space-time domain pairing a spatial geometry with a time domain.

    Points are handled as rows of a 2-D array: the last column is the time
    coordinate and all preceding columns are forwarded to the spatial
    geometry.
    """

    def __init__(self, geometry, timedomain):
        """Store both factors and register the combined dimension.

        Args:
            geometry: Spatial geometry; must expose ``dim`` and the methods
                this class delegates to.
            timedomain: Temporal domain; must expose ``dim``, ``t0``, ``t1``,
                ``inside``, ``on_initial`` and ``random_points``.
        """
        self.geometry = geometry
        self.timedomain = timedomain
        super().__init__(geometry.dim + timedomain.dim)

    def inside(self, x):
        """Return, per row, whether the point lies in both the geometry and the time domain."""
        return jnp.logical_and(
            self.geometry.inside(x[:, :-1]),
            self.timedomain.inside(x[:, -1:]),
        )

    def on_boundary(self, x):
        """Return the spatial geometry's boundary test; the time column is ignored."""
        return self.geometry.on_boundary(x[:, :-1])

    def on_initial(self, x):
        """Return the time domain's ``on_initial`` test applied to the time column."""
        return self.timedomain.on_initial(x[:, -1:])

    def boundary_normal(self, x):
        """Return the spatial boundary normal extended with a zero time component."""
        _n = self.geometry.boundary_normal(x[:, :-1])
        return jnp.hstack([_n, jnp.zeros((len(_n), 1))])

    def random_points(self, n, random="pseudo"):
        """Sample ``n`` points from the space-time domain.

        When the spatial geometry is an ``Interval``, ``Rectangle``,
        ``Cuboid`` or ``Hypercube``, a box with one extra (time) axis is
        built and sampled directly. Otherwise space and time are sampled
        separately and joined column-wise.

        Args:
            n: Number of points requested.
            random: Sampling scheme name, forwarded unchanged to the
                underlying samplers.

        Returns:
            Array whose rows are ``(spatial coordinates..., t)``.
        """
        # The order of these checks is kept as-is: the geometry classes may
        # be related by inheritance, so the first matching branch matters.
        if isinstance(self.geometry, Interval):
            geom = Rectangle(
                [self.geometry.l, self.timedomain.t0],
                [self.geometry.r, self.timedomain.t1],
            )
            return geom.random_points(n, random=random)

        if isinstance(self.geometry, Rectangle):
            geom = Cuboid(
                [self.geometry.xmin[0], self.geometry.xmin[1], self.timedomain.t0],
                [self.geometry.xmax[0], self.geometry.xmax[1], self.timedomain.t1],
            )
            return geom.random_points(n, random=random)

        if isinstance(self.geometry, (Cuboid, Hypercube)):
            geom = Hypercube(
                jnp.append(self.geometry.xmin, self.timedomain.t0),
                jnp.append(self.geometry.xmax, self.timedomain.t1),
            )
            return geom.random_points(n, random=random)

        x = self.geometry.random_points(n, random=random)
        t = self.timedomain.random_points(n, random=random)
        # Shuffle the time samples before pairing them with the spatial ones
        # (presumably to decorrelate the two independently generated sequences).
        t = brainstate.random.permutation(t)
        return jnp.hstack((x, t))

    def random_boundary_points(self, n, random="pseudo"):
        """Sample ``n`` points on the spatial boundary at shuffled random times.

        Args:
            n: Number of points requested.
            random: Sampling scheme name, forwarded unchanged to the
                underlying samplers.

        Returns:
            Array whose rows are ``(boundary coordinates..., t)``.
        """
        x = self.geometry.random_boundary_points(n, random=random)
        t = self.timedomain.random_points(n, random=random)
        t = brainstate.random.permutation(t)
        return jnp.hstack((x, t))

    def uniform_initial_points(self, n):
        """Return the geometry's uniform points placed at the initial time ``t0``.

        The geometry may return a number of points different from ``n``; in
        that case a warning is printed and all returned points are kept.

        Args:
            n: Number of points requested.

        Returns:
            Array whose rows are ``(spatial coordinates..., t0)``.
        """
        # NOTE(review): the second positional argument is presumably the
        # geometry's "include boundary" flag -- confirm against uniform_points.
        x = self.geometry.uniform_points(n, True)
        t = self.timedomain.t0
        sampled = len(x)
        if n != sampled:
            print(f"Warning: {n} points required, but {sampled} points sampled.")
        return jnp.hstack((x, jnp.full([sampled, 1], t, dtype=brainstate.environ.dftype())))

    def random_initial_points(self, n, random="pseudo"):
        """Return random spatial points placed at the initial time ``t0``.

        Args:
            n: Number of points requested.
            random: Sampling scheme name, forwarded unchanged to the
                geometry's sampler.

        Returns:
            Array whose rows are ``(spatial coordinates..., t0)``.
        """
        x = self.geometry.random_points(n, random=random)
        t = self.timedomain.t0
        # Size the time column from the points actually returned rather than
        # from the request, so the stack cannot fail on mismatched row counts
        # (same approach as uniform_initial_points).
        return jnp.hstack((x, jnp.full([len(x), 1], t, dtype=brainstate.environ.dftype())))

    def periodic_point(self, x, component):
        """Apply the geometry's periodic mapping to the spatial columns, keeping time unchanged."""
        xp = self.geometry.periodic_point(x[:, :-1], component)
        return jnp.hstack([xp, x[:, -1:]])