# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================
import time
from typing import Union, Sequence, Callable, Optional
import braintools
import brainstate
import brainunit as u
import jax.numpy as jnp
import jax.tree
import numpy as np
from . import metrics as metrics_module
from . import utils
from .callbacks import CallbackList, Callback
from .problem.base import Problem
from .utils._display import training_display
# Public API of this module.
__all__ = ["Trainer", "TrainState", "LossHistory"]
class Trainer:
    """
    A ``Trainer`` trains a neural network on a ``Problem``.

    Args:
        problem: ``pinnx.problem.Problem`` instance. Its ``approximator`` must not be ``None``.
        external_trainable_variables: A trainable ``brainstate.ParamState`` object or a list
            (or tuple) of trainable ``brainstate.ParamState`` objects. The unknown parameters
            in the physics systems that need to be recovered.
        batch_size: Batch size passed to ``problem.train_next_batch`` when ``compile`` is
            asked to pre-compile the training step (``measture_train_step_compile_time=True``).
    """
    __module__ = 'pinnx'

    optimizer: braintools.optim.Optimizer  # optimizer
    problem: Problem  # problem
    params: brainstate.util.FlattedDict  # trainable variables

    def __init__(
        self,
        problem: Problem,
        external_trainable_variables: Union[brainstate.ParamState, Sequence[brainstate.ParamState]] = None,
        batch_size: Optional[int] = None,
    ):
        # the problem
        self.problem = problem
        assert isinstance(self.problem, Problem), "problem must be a Problem instance."

        # the approximator
        if self.problem.approximator is None:
            raise ValueError("Problem must define an approximator before training.")

        # parameters of the approximator, extended with the external trainable variables
        params = brainstate.graph.states(self.problem.approximator, brainstate.ParamState)
        if external_trainable_variables is None:
            external_trainable_variables = []
        elif isinstance(external_trainable_variables, (list, tuple)):
            # Accept any list/tuple, as the ``Sequence`` annotation promises; a tuple used
            # to be wrapped as one element and then rejected by the assertion below.
            external_trainable_variables = list(external_trainable_variables)
        else:
            external_trainable_variables = [external_trainable_variables]
        for i, var in enumerate(external_trainable_variables):
            assert isinstance(var, brainstate.ParamState), ("external_trainable_variables must be a "
                                                            "list of ParamState instance.")
            params[('external_trainable_variable', i)] = var
        self.params = params

        # other useful parameters
        self.metrics = None  # set by ``compile``; ``train`` refuses to run while it is None
        self.batch_size = batch_size

        # training state
        self.train_state = TrainState()
        self.loss_history = LossHistory()
        self.stop_training = False

    @utils.timing
    def compile(
        self,
        optimizer: braintools.optim.Optimizer,
        metrics: Union[str, Sequence[str]] = None,
        measture_train_step_compile_time: bool = False,
    ):
        """
        Configures the trainer for training.

        Registers ``self.params`` with ``optimizer``, resolves the metrics and builds the
        jitted functions used for prediction, loss evaluation and one optimization step.

        Args:
            optimizer: A ``braintools.optim.Optimizer`` instance.
            metrics: A metric name or a list of metric names (resolved with
                ``metrics_module.get``), evaluated on the test data during training.
            measture_train_step_compile_time: If ``True``, compile the training step ahead
                of time on one batch of size ``self.batch_size`` and report the elapsed time.

        Returns:
            ``self``, or ``(self, seconds)`` when ``measture_train_step_compile_time`` is ``True``.
        """
        print("Compiling trainer...")

        # optimizer
        assert isinstance(optimizer, braintools.optim.Optimizer), "optimizer must be an Optimizer instance."
        self.optimizer = optimizer
        self.optimizer.register_trainable_weights(self.params)

        # metrics may use trainer variables such as self.net,
        # and thus are instantiated after compile.
        metrics = metrics or []
        if isinstance(metrics, str):
            # A bare string is one metric name; iterating it would yield single characters.
            metrics = [metrics]
        self.metrics = [metrics_module.get(m) for m in metrics]

        def fn_outputs(training: bool, inputs):
            # Network outputs, with ``fit`` set in the brainstate environment.
            with brainstate.environ.context(fit=training):
                inputs = jax.tree.map(lambda x: u.math.asarray(x), inputs, is_leaf=u.math.is_quantity)
                return self.problem.approximator(inputs)

        def fn_outputs_losses(training, inputs, targets, **kwargs):
            # Network outputs and the problem's train/test losses for them.
            with brainstate.environ.context(fit=training):
                # inputs
                inputs = jax.tree.map(lambda x: u.math.asarray(x), inputs, is_leaf=u.math.is_quantity)

                # outputs
                outputs = self.problem.approximator(inputs)

                # targets
                if targets is not None:
                    targets = jax.tree.map(lambda x: u.math.asarray(x), targets, is_leaf=u.math.is_quantity)

                # compute losses
                if training:
                    losses = self.problem.losses_train(inputs, outputs, targets, **kwargs)
                else:
                    losses = self.problem.losses_test(inputs, outputs, targets, **kwargs)
                return outputs, losses

        def fn_outputs_losses_train(inputs, targets, **aux):
            return fn_outputs_losses(True, inputs, targets, **aux)

        def fn_outputs_losses_test(inputs, targets, **aux):
            return fn_outputs_losses(False, inputs, targets, **aux)

        def fn_train_step(inputs, targets, **aux):
            def _loss_fun():
                # Total scalar loss: every loss leaf is summed, then the sums are added.
                losses = fn_outputs_losses_train(inputs, targets, **aux)[1]
                return u.math.sum(u.math.asarray([loss.sum() for loss in jax.tree.leaves(losses)]))

            grads = brainstate.transform.grad(_loss_fun, grad_states=self.params)()
            self.optimizer.update(grads)

        # Callables
        self.fn_outputs = brainstate.transform.jit(fn_outputs, static_argnums=0)  # ``training`` is static
        self.fn_outputs_losses_train = brainstate.transform.jit(fn_outputs_losses_train)
        self.fn_outputs_losses_test = brainstate.transform.jit(fn_outputs_losses_test)
        self.fn_train_step = brainstate.transform.jit(fn_train_step)

        if measture_train_step_compile_time:
            # perf_counter is monotonic and high resolution, suited to measuring durations.
            t0 = time.perf_counter()
            self._compile_training_step(self.batch_size)
            t1 = time.perf_counter()
            return self, t1 - t0
        return self

    @utils.timing
    def train(
        self,
        iterations: int,
        batch_size: int = None,
        display_every: int = 1000,
        disregard_previous_best: bool = False,
        callbacks: Union[Callback, Sequence[Callback]] = None,
        model_restore_path: str = None,
        model_save_path: str = None,
        measture_train_step_time: bool = False,
    ):
        """
        Trains the trainer.

        Args:
            iterations (Integer): Number of iterations to train the trainer, i.e., number
                of times the network weights are updated.
            batch_size: Integer, tuple, or ``None``.

                - If you solve PDEs via ``pinnx.problem.PDE`` or ``pinnx.problem.TimePDE``, do not use `batch_size`,
                  and instead use `pinnx.callbacks.PDEPointResampler
                  <https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEPointResampler>`_,
                  see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/diffusion_1d_resample.py>`_.
                - For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
                  then it is the batch size for the branch input;
                  if you want to also use mini-batch for the trunk net input,
                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
                  and the second number is the batch size for the trunk net input.
            display_every (Integer): Evaluate, record and print the loss and metrics every
                this many steps (and always after the last iteration).
            disregard_previous_best: If ``True``, disregard the previous saved best
                trainer.
            callbacks: List of ``pinnx.callbacks.Callback`` instances. List of callbacks
                to apply during training.
            model_restore_path (String): Path where parameters were previously saved.
            model_save_path (String): Prefix of filenames created for the checkpoint.
            measture_train_step_time: If ``True``, also return the elapsed time of this call.

        Returns:
            ``self``, or ``(self, seconds)`` when ``measture_train_step_time`` is ``True``.

        Raises:
            ValueError: If ``compile`` has not been called.
        """
        if measture_train_step_time:
            t0 = time.perf_counter()

        if self.metrics is None:
            raise ValueError("Compile the trainer before training.")

        # callbacks
        callbacks = CallbackList(callbacks=[callbacks] if isinstance(callbacks, Callback) else callbacks)
        callbacks.set_model(self)

        # disregard previous best
        if disregard_previous_best:
            self.train_state.disregard_best()

        # restore
        if model_restore_path is not None:
            self.restore(model_restore_path, verbose=1)

        print("Training trainer...\n")
        self.stop_training = False

        # testing: evaluate once before the first update
        self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))
        self.train_state.set_data_test(*self.problem.test())
        self._test()

        # training
        callbacks.on_train_begin()
        self._train(iterations, display_every, batch_size, callbacks)
        callbacks.on_train_end()

        # summary
        print("")
        training_display.summary(self.train_state)
        if model_save_path is not None:
            self.save(model_save_path, verbose=1)

        if measture_train_step_time:
            t1 = time.perf_counter()
            return self, t1 - t0
        return self

    def _compile_training_step(self, batch_size=None):
        """Compile ``fn_train_step`` ahead of time on one freshly drawn training batch."""
        # get data
        self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))

        # train one batch
        self.fn_train_step.compile(self.train_state.X_train,
                                   self.train_state.y_train,
                                   **self.train_state.Aux_train)

    def _train(self, iterations, display_every, batch_size, callbacks):
        """Run up to ``iterations`` optimization steps, evaluating every ``display_every`` steps."""
        for i in range(iterations):
            callbacks.on_epoch_begin()
            callbacks.on_batch_begin()

            # get data
            self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))

            # train one batch
            self.fn_train_step(self.train_state.X_train, self.train_state.y_train, **self.train_state.Aux_train)

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            callbacks.on_batch_end()
            callbacks.on_epoch_end()

            # set by ``_test`` on NaN losses, or possibly by a callback
            if self.stop_training:
                break

    def _test(self):
        """Evaluate losses/metrics on the current train and test data, record and display them."""
        # evaluate the training data
        (
            self.train_state.y_pred_train,
            self.train_state.loss_train,
        ) = self.fn_outputs_losses_train(
            self.train_state.X_train,
            self.train_state.y_train,
            **self.train_state.Aux_train,
        )

        # evaluate the test data
        (
            self.train_state.y_pred_test,
            self.train_state.loss_test
        ) = self.fn_outputs_losses_test(
            self.train_state.X_test,
            self.train_state.y_test,
            **self.train_state.Aux_test,
        )

        # metrics: one value per metric, or per (metric, output) pair for multiple outputs
        if isinstance(self.train_state.y_test, (list, tuple)):
            self.train_state.metrics_test = [
                m(self.train_state.y_test[i],
                  self.train_state.y_pred_test[i])
                for m in self.metrics
                for i in range(len(self.train_state.y_test))
            ]
        else:
            self.train_state.metrics_test = [
                m(self.train_state.y_test,
                  self.train_state.y_pred_test)
                for m in self.metrics
            ]

        # history
        self.train_state.update_best()
        self.loss_history.append(
            self.train_state.step,
            self.train_state.loss_train,
            self.train_state.loss_test,
            self.train_state.metrics_test,
        )

        # check NaN: stop training as soon as any loss term becomes NaN
        if self._has_nan(self.train_state.loss_train) or self._has_nan(self.train_state.loss_test):
            self.stop_training = True

        # display
        training_display(self.train_state)

    @staticmethod
    def _has_nan(losses) -> bool:
        """Return ``True`` if any leaf of the pytree ``losses`` contains a NaN.

        Leaves are checked one at a time, so loss terms of different shapes are
        supported (stacking them into a single array would fail).
        """
        return any(bool(jnp.isnan(jnp.asarray(leaf)).any()) for leaf in jax.tree.leaves(losses))

    def predict(
        self,
        xs,
        operator: Optional[Callable] = None,
        callbacks: Union[Callback, Sequence[Callback]] = None,
    ):
        """Generates predictions for the input samples. If `operator` is ``None``,
        returns the network output, otherwise returns the output of the `operator`.

        Args:
            xs: The network inputs. A Numpy array or a tuple of Numpy arrays. Leaves are
                converted to ``brainstate.environ.dftype()``.
            operator: A function called as ``operator(inputs, outputs)``, where `inputs` and
                `outputs` are the network input and output tensors, respectively. `operator`
                is typically chosen as the PDE (used to define `pinnx.problem.PDE`) to
                predict the PDE residual.
            callbacks: List of ``pinnx.callbacks.Callback`` instances. List of callbacks
                to apply during prediction.
        """
        xs = jax.tree.map(
            lambda x: u.math.asarray(x, dtype=brainstate.environ.dftype()),
            xs,
            is_leaf=u.math.is_quantity
        )
        callbacks = CallbackList(callbacks=[callbacks] if isinstance(callbacks, Callback) else callbacks)
        callbacks.set_model(self)
        callbacks.on_predict_begin()
        ys = self.fn_outputs(False, xs)  # evaluation mode (fit=False)
        if operator is not None:
            ys = operator(xs, ys)
        callbacks.on_predict_end()
        return ys

    def save(self, save_path, verbose: int = 0):
        """Saves all variables to a disk file.

        The parameters and optimizer states are written to
        ``f"{save_path}-{epoch}.msgpack"``.

        Args:
            save_path (string): Prefix of filenames to save the trainer file.
            verbose (int): Verbosity mode, 0 or 1.

        Returns:
            string: Path where trainer is saved.
        """
        # save path
        save_path = f"{save_path}-{self.train_state.epoch}.msgpack"

        # Collect states through one graph so a ParamState shared by the parameters
        # and the optimizer is saved only once.
        model = brainstate.graph.Dict(params=self.params, optimizer=self.optimizer)
        checkpoint = brainstate.graph.states(model).to_nest()
        braintools.file.msgpack_save(save_path, checkpoint)

        if verbose > 0:
            print(
                "Epoch {}: saving trainer to {} ...\n".format(
                    self.train_state.epoch, save_path
                )
            )
        return save_path

    def restore(self, save_path, verbose: int = 0):
        """Restore all variables from a disk file.

        Args:
            save_path (string): Path where trainer was previously saved.
            verbose (int): Verbosity mode, 0 or 1.
        """
        if verbose > 0:
            print("Restoring trainer from {} ...\n".format(save_path))
        data = brainstate.graph.Dict(params=self.params, optimizer=self.optimizer)
        checkpoint = brainstate.graph.states(data).to_nest()
        # NOTE(review): the return value is discarded; this relies on msgpack_load writing
        # the loaded values into the target states in place — confirm.
        braintools.file.msgpack_load(save_path, target=checkpoint)

    def saveplot(
        self,
        issave: bool = True,
        isplot: bool = True,
        loss_fname: str = "loss.dat",
        train_fname: str = "train.dat",
        test_fname: str = "test.dat",
        output_dir: str = None,
    ):
        """
        Saves and plots the loss and metrics.

        Args:
            issave: If ``True``, save the loss and metrics to files.
            isplot: If ``True``, plot the loss and metrics.
            loss_fname: Filename to save the loss.
            train_fname: Filename to save the training metrics.
            test_fname: Filename to save the test metrics.
            output_dir: Directory to save the files.
        """
        utils.saveplot(
            self.loss_history,
            self.train_state,
            issave=issave,
            isplot=isplot,
            loss_fname=loss_fname,
            train_fname=train_fname,
            test_fname=test_fname,
            output_dir=output_dir,
        )
class TrainState:
    """
    Mutable record of the current training data and results, plus the best
    results seen so far (the step with the smallest total training loss).
    """
    __module__ = 'pinnx'

    def __init__(self):
        self.epoch = 0
        self.step = 0

        # Current data
        self.X_train = None
        self.y_train = None
        self.Aux_train = dict()  # extra keyword arguments forwarded to the loss functions
        self.X_test = None
        self.y_test = None
        self.Aux_test = dict()

        # Results of current step
        # Train results
        self.loss_train = None
        self.y_pred_train = None
        # Test results
        self.loss_test = None
        self.y_pred_test = None
        self.y_std_test = None
        self.metrics_test = None

        # The best results correspond to the min train loss
        self.best_step = 0
        self.best_loss_train = np.inf
        self.best_loss_test = np.inf
        self.best_y = None
        self.best_ystd = None
        self.best_metrics = None

    def set_data_train(self, X_train, y_train, *args):
        """Store the current training batch; an optional third argument is the auxiliary dict.

        When no auxiliary dict is given, the previous ``Aux_train`` is kept.
        """
        self.X_train = X_train
        self.y_train = y_train
        if len(args) > 0:
            assert len(args) == 1, "Auxiliary training data must be a single argument."
            assert isinstance(args[0], dict), "Auxiliary training data must be a dictionary."
            self.Aux_train = args[0]

    def set_data_test(self, X_test, y_test, *args):
        """Store the test data; an optional third argument is the auxiliary dict.

        When no auxiliary dict is given, the previous ``Aux_test`` is kept.
        """
        self.X_test = X_test
        self.y_test = y_test
        if len(args) > 0:
            assert len(args) == 1, "Auxiliary test data must be a single argument."
            assert isinstance(args[0], dict), "Auxiliary test data must be a dictionary."
            self.Aux_test = args[0]

    @staticmethod
    def _total_loss(losses):
        """Sum every leaf of the pytree ``losses`` into one scalar.

        Each leaf is reduced on its own before the partial sums are added, so loss
        terms of different shapes are supported. This mirrors how ``Trainer``'s
        training step sums each loss leaf. An empty tree sums to 0.
        """
        return jnp.sum(jnp.asarray([jnp.sum(jnp.asarray(leaf)) for leaf in jax.tree.leaves(losses)]))

    def update_best(self):
        """Record the current step as the best one if its total training loss is lower."""
        current_loss_train = self._total_loss(self.loss_train)
        if self.best_loss_train > current_loss_train:
            self.best_step = self.step
            self.best_loss_train = current_loss_train
            self.best_loss_test = self._total_loss(self.loss_test)
            self.best_y = self.y_pred_test
            self.best_ystd = self.y_std_test
            self.best_metrics = self.metrics_test

    def disregard_best(self):
        """Forget the best training loss so that the next ``update_best`` always records."""
        self.best_loss_train = np.inf
class LossHistory:
    """
    Append-only history of the values recorded at each evaluation step.

    The four lists ``steps``, ``loss_train``, ``loss_test`` and ``metrics_test``
    stay aligned: entry ``k`` of each belongs to the same recorded step.
    """
    __module__ = 'pinnx'

    def __init__(self):
        self.steps = []
        self.loss_train = []
        self.loss_test = []
        self.metrics_test = []

    def append(self, step, loss_train, loss_test, metrics_test):
        """Record one evaluation step.

        A ``None`` ``loss_test`` / ``metrics_test`` repeats the previously recorded
        value. If there is no earlier value to repeat, ``None`` is stored.
        Previously this case raised ``IndexError`` after ``steps`` and ``loss_train``
        had already grown, which left the lists misaligned.
        """
        self.steps.append(step)
        self.loss_train.append(loss_train)
        if loss_test is None and self.loss_test:
            loss_test = self.loss_test[-1]
        if metrics_test is None and self.metrics_test:
            metrics_test = self.metrics_test[-1]
        self.loss_test.append(loss_test)
        self.metrics_test.append(metrics_test)