Source code for pinnx.icbc.initial_conditions

# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================

from __future__ import annotations

from typing import Callable, Dict

import brainstate
import jax
import numpy as np

from .base import ICBC

__all__ = ["IC"]


[docs] class IC(ICBC): """ Initial conditions: ``y([x, t0]) = func([x, t0])``. Args: func: Function that returns the initial conditions. This function should take a dictionary of collocation points and return a dictionary of initial conditions. For example:: import brainunit as u def func(x): return {'y': -u.math.sin(np.pi * x['x'] / u.meter) * u.meter / u.second} on_initial: Filter function for initial conditions. This function should take a dictionary of collocation points and return a boolean array indicating whether the points are initial conditions. For example:: def on_initial(x, on): return on """ def __init__( self, func: Callable[[Dict, ...], Dict] | Callable[[Dict], Dict], on_initial: Callable[[Dict, np.array], np.array] = lambda x, on: on, ): self.func = func self.on_initial = lambda x, on: jax.vmap(on_initial)(x, on)
[docs] def filter(self, X): """ Filter the collocation points for initial conditions. Args: X: Collocation points. Returns: Filtered collocation points. """ # the "geometry" should be "TimeDomain" or "GeometryXTime" positions = self.on_initial(X, self.geometry.on_initial(X)) return jax.tree.map(lambda x: x[positions], X)
[docs] def collocation_points(self, X): """ Return the collocation points for initial conditions. Args: X: Collocation points. Returns: Collocation points for initial conditions. """ return self.filter(X)
[docs] def error(self, inputs, outputs, **kwargs) -> Dict[str, brainstate.typing.ArrayLike]: """ Error for initial conditions. Compare the initial conditions with the outputs. Args: inputs: Collocation points. outputs: Collocation values. Returns: Error for initial conditions. """ values = self.func(inputs, **kwargs) errors = dict() for key, value in values.items(): errors[key] = outputs[key] - value return errors