# Source code for pinnx.problem.dataset_general

# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================


from typing import Sequence, Dict

import brainstate
import jax
import numpy as np

from pinnx import utils
from .base import Problem

__all__ = [
    'DataSet'
]


class DataSet(Problem):
    """Fitting problem backed by fixed, pre-split train and test data.

    Args:
        X_train: Training inputs, as a dict keyed by name.
        y_train: Training targets, as a dict keyed by name.
        X_test: Testing inputs, as a dict keyed by name.
        y_test: Testing targets, as a dict keyed by name.
        standardize: When True, the training and testing inputs are passed
            pairwise through ``utils.standardize`` and the stored inputs are
            replaced by the results. Targets are never standardized.
            Defaults to False.
        approximator: Forwarded unchanged to ``Problem.__init__``.
        loss_fn: Forwarded unchanged to ``Problem.__init__``.
            Defaults to ``'MSE'``.
        loss_weights: Forwarded unchanged to ``Problem.__init__``.
    """

    def __init__(
        self,
        X_train: Dict[str, brainstate.typing.ArrayLike],
        y_train: Dict[str, brainstate.typing.ArrayLike],
        X_test: Dict[str, brainstate.typing.ArrayLike],
        y_test: Dict[str, brainstate.typing.ArrayLike],
        standardize: bool = False,
        approximator: brainstate.nn.Module = None,
        loss_fn: str = 'MSE',
        loss_weights: Sequence[float] = None,
    ):
        super().__init__(
            approximator=approximator,
            loss_fn=loss_fn,
            loss_weights=loss_weights,
        )

        self.train_x = X_train
        self.train_y = y_train
        self.test_x = X_test
        self.test_y = y_test

        # Initialised to None and not assigned anywhere else in this class.
        self.scaler_x = None

        if not standardize:
            return

        # Walk the train/test input trees in lockstep; every pair of matching
        # leaves is replaced by whatever ``utils.standardize`` returns for it.
        scaled = jax.tree.map(utils.standardize, self.train_x, self.test_x)

        # NOTE(review): the indexing below assumes ``utils.standardize``
        # returns ``(train, test)`` in that order -- confirm against its
        # definition, since it is not visible from this file.
        self.train_x = {name: pair[0] for name, pair in scaled.items()}
        self.test_x = {name: pair[1] for name, pair in scaled.items()}

    def losses(self, inputs, outputs, targets, **kwargs):
        """Return the configured loss between ``targets`` and ``outputs``.

        ``inputs`` and any extra keyword arguments are accepted but unused.
        """
        loss = self.loss_fn(targets, outputs)
        return loss

    def train_next_batch(self, batch_size=None):
        """Return the full training set as ``(inputs, targets)``.

        ``batch_size`` is accepted but ignored; no mini-batching happens here.
        """
        batch = (self.train_x, self.train_y)
        return batch

    def test(self):
        """Return the testing set as ``(inputs, targets)``."""
        held_out = (self.test_x, self.test_y)
        return held_out