# Source code for pinnx.problem.dataset_function
# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
from typing import Callable, Sequence
import brainstate
from pinnx.geometry.base import AbstractGeometry
from pinnx.utils import run_if_any_none
from .base import Problem
# Names exported by ``from <this module> import *``.
__all__ = ['Function']
[docs]
class Function(Problem):
    """
    Problem that fits a network to a known target function.

    Training inputs are sampled from ``geometry`` and labelled by calling
    ``function`` on them. Test inputs always come from
    ``geometry.uniform_points(num_test, boundary=True)``.

    Args:
        geometry: The domain of the function. Must provide ``uniform_points``
            and ``random_points``.
        function: The function to be approximated. A callable that takes the
            sampled points as input and returns the corresponding function
            values.
        num_train (int): The number of training points sampled inside the
            domain.
        num_test (int): The number of points for testing.
        train_distribution (string): The distribution to sample training
            points. One of the following: "uniform" (equispaced grid),
            "pseudo" (pseudorandom), "LHS" (Latin hypercube sampling),
            "Halton" (Halton sequence), "Hammersley" (Hammersley sequence),
            or "Sobol" (Sobol sequence). Any value other than "uniform" is
            passed to ``geometry.random_points`` as ``random=...``.
        online (bool): If ``True``, resample the training points on every
            call to ``train_next_batch``; otherwise reuse the first sample.
            When ``True``, the distribution is forced to "pseudo" and a
            warning is printed if another one was requested.
        approximator: Forwarded unchanged to ``Problem.__init__``.
        loss_fn: Forwarded unchanged to ``Problem.__init__``.
        loss_weights: Forwarded unchanged to ``Problem.__init__``.
    """

    def __init__(
        self,
        geometry: AbstractGeometry,
        function: Callable,
        num_train: int,
        num_test: int,
        train_distribution: str = "uniform",
        online: bool = False,
        approximator: brainstate.nn.Module = None,
        loss_fn: str = 'MSE',
        loss_weights: Sequence[float] = None,
    ):
        super().__init__(
            approximator=approximator,
            loss_fn=loss_fn,
            loss_weights=loss_weights,
        )

        self.geom = geometry
        self.func = function
        self.num_train = num_train
        self.num_test = num_test
        self.online = online

        # Online mode only makes sense with pseudorandom sampling, so any
        # other requested distribution is overridden (with a notice).
        force_pseudo = online and train_distribution != "pseudo"
        if force_pseudo:
            print("Warning: Online learning should use pseudorandom sampling.")
        self.dist_train = "pseudo" if force_pseudo else train_distribution

        # Data is generated lazily, on the first request for it.
        self.train_x = self.train_y = None
        self.test_x = self.test_y = None

    def losses(self, inputs, outputs, targets, **kwargs):
        """Return ``self.loss_fn(targets, outputs)``.

        ``inputs`` and any extra keyword arguments are accepted but unused.
        """
        loss = self.loss_fn(targets, outputs)
        return loss

    def train_next_batch(self, batch_size=None):
        """Return the ``(train_x, train_y)`` pair used for training.

        The pair is generated on the first call and then reused, unless
        ``self.online`` is truthy, in which case it is regenerated on every
        call. ``batch_size`` is accepted but unused.
        """
        already_sampled = self.train_x is not None
        if already_sampled and not self.online:
            return self.train_x, self.train_y

        # Store the inputs first, then label them, so the targets always
        # correspond to the points currently held in ``self.train_x``.
        self.train_x = (
            self.geom.uniform_points(self.num_train, boundary=True)
            if self.dist_train == "uniform"
            else self.geom.random_points(self.num_train, random=self.dist_train)
        )
        self.train_y = self.func(self.train_x)
        return self.train_x, self.train_y

    # NOTE(review): ``run_if_any_none`` is defined in ``pinnx.utils``;
    # presumably it runs the body only while ``test_x`` or ``test_y`` is still
    # ``None`` and otherwise returns the stored values — confirm there.
    @run_if_any_none("test_x", "test_y")
    def test(self):
        """Build and return the ``(test_x, test_y)`` pair.

        Test inputs are ``num_test`` points from
        ``geometry.uniform_points(..., boundary=True)``; targets are
        ``function`` evaluated on those points.
        """
        self.test_x = self.geom.uniform_points(self.num_test, boundary=True)
        self.test_y = self.func(self.test_x)
        return self.test_x, self.test_y