Source code for pinnx.nn.base
# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
from typing import Optional, Callable
import brainstate
import jax.tree
[docs]
class NN(brainstate.nn.Module):
    """Base class for all neural network modules.

    The constructor records two optional callables, an input transform and
    an output transform, and initialises ``regularization`` to ``None``.
    This class only stores them; how they are used is up to subclasses.
    """

    def __init__(
        self,
        input_transform: Optional[Callable] = None,
        output_transform: Optional[Callable] = None,
    ):
        """Initialise the module.

        Args:
            input_transform: Optional callable kept on the instance as
                ``_input_transform``.
            output_transform: Optional callable kept on the instance as
                ``_output_transform``.
        """
        super().__init__()
        self.regularization = None
        self._input_transform = input_transform
        self._output_transform = output_transform

    def num_trainable_parameters(self) -> int:
        """Evaluate the number of trainable parameters for the NN.

        Returns:
            The sum of ``size`` over every pytree leaf of every
            ``brainstate.ParamState`` reported by ``self.states``; ``0``
            when there are no such states.
        """
        # The per-leaf sizes must be summed: adding the list of sizes to the
        # integer accumulator raises ``TypeError``.  ``jax.tree.leaves`` is
        # the supported spelling of the removed ``jax.tree_leaves`` alias.
        return sum(
            leaf.size
            for state in self.states(brainstate.ParamState).values()
            for leaf in jax.tree.leaves(state)
        )