# Source code for pinnx.callbacks

# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================

import sys
import time
from typing import Dict, Callable

import brainstate
import brainunit as u
import jax.tree
import numpy as np

from . import utils

# Public names exported by this module (one entry per callback class below).
__all__ = [
    'Callback',
    'CallbackList',
    'ModelCheckpoint',
    'EarlyStopping',
    'Timer',
    'DropoutUncertainty',
    'VariableValue',
    'OperatorPredictor',
    'MovieDumper',
    'PDEPointResampler',
]


class Callback:
    """Base class for all training callbacks.

    Every hook defined here is a no-op; subclasses override the ones they
    need.

    Attributes:
        model: The trainer this callback is attached to. ``None`` until
            ``set_model`` has been called.
    """

    def __init__(self):
        self.model = None

    def set_model(self, model):
        """Attach ``model`` and run ``init`` if it differs from the current one."""
        # Re-attaching the very same object must not trigger ``init`` again.
        if model is self.model:
            return
        self.model = model
        self.init()

    def init(self):
        """Hook run once a new trainer has been attached."""

    def on_epoch_begin(self):
        """Hook run at the start of each epoch."""

    def on_epoch_end(self):
        """Hook run at the end of each epoch."""

    def on_batch_begin(self):
        """Hook run at the start of each batch."""

    def on_batch_end(self):
        """Hook run at the end of each batch."""

    def on_train_begin(self):
        """Hook run when training starts."""

    def on_train_end(self):
        """Hook run when training finishes."""

    def on_predict_begin(self):
        """Hook run when prediction starts."""

    def on_predict_end(self):
        """Hook run when prediction finishes."""
class CallbackList(Callback):
    """Container that forwards every hook to a list of callbacks.

    Hooks are dispatched in list order.

    Args:
        callbacks: Iterable of ``Callback`` instances, or ``None`` (or any
            falsy value) for an empty list. The iterable is copied.
    """

    def __init__(self, callbacks=None):
        super().__init__()
        self.callbacks = list(callbacks or [])

    def set_model(self, model):
        """Store ``model`` and attach it to every contained callback."""
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def on_epoch_begin(self):
        """Forward ``on_epoch_begin`` to every callback."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Forward ``on_epoch_end`` to every callback."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_batch_begin(self):
        """Forward ``on_batch_begin`` to every callback."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Forward ``on_batch_end`` to every callback."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_train_begin(self):
        """Forward ``on_train_begin`` to every callback."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Forward ``on_train_end`` to every callback."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_predict_begin(self):
        """Forward ``on_predict_begin`` to every callback."""
        for callback in self.callbacks:
            callback.on_predict_begin()

    def on_predict_end(self):
        """Forward ``on_predict_end`` to every callback."""
        for callback in self.callbacks:
            callback.on_predict_end()

    def append(self, callback):
        """Add ``callback`` to the end of the list.

        The appended callback is not attached to the current trainer here;
        that only happens on the next ``set_model`` call.

        Raises:
            TypeError: If ``callback`` is not a ``Callback`` instance.
        """
        if not isinstance(callback, Callback):
            # TypeError is a subclass of Exception, so callers that caught the
            # previous bare ``Exception`` keep working.
            raise TypeError(str(callback) + " is an invalid Callback object")
        self.callbacks.append(callback)
class ModelCheckpoint(Callback):
    """Save the trainer every ``period`` calls of ``on_epoch_end``.

    Args:
        filepath (string): Prefix of filenames to save the trainer file;
            passed unchanged to ``self.model.save``.
        verbose: Verbosity mode, 0 or 1.
        save_better_only: If True, save only when the monitored quantity is
            strictly lower than the best value seen so far.
        period: Interval (number of ``on_epoch_end`` calls) between
            checkpoints.
        monitor: The loss that is monitored. Either 'train loss' or
            'test loss'.
    """

    def __init__(
        self,
        filepath,
        verbose=0,
        save_better_only=False,
        period=1,
        monitor="train loss",
    ):
        super().__init__()
        self.filepath = filepath
        self.verbose = verbose
        self.save_better_only = save_better_only
        self.period = period
        self.monitor = monitor
        # Lower loss is better, so start from +inf and compare with ``<``.
        self.monitor_op = np.less
        self.epochs_since_last_save = 0
        self.best = np.inf

    def on_epoch_end(self):
        """Save a checkpoint once ``period`` calls have accumulated."""
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        if not self.save_better_only:
            self.model.save(self.filepath, verbose=self.verbose)
            return

        current = self.get_monitor_value()
        if not self.monitor_op(current, self.best):
            return
        # Save quietly; the improvement message below replaces the saver's own.
        save_path = self.model.save(self.filepath, verbose=0)
        if self.verbose > 0:
            print(
                "Epoch {}: {} improved from {:.2e} to {:.2e}, saving trainer to {} ...\n".format(
                    self.model.train_state.epoch,
                    self.monitor,
                    self.best,
                    current,
                    save_path,
                )
            )
        self.best = current

    def get_monitor_value(self):
        """Return the monitored loss summed over all of its components.

        The loss container is flattened with ``jax.tree.leaves`` before
        summing, the same way ``EarlyStopping.get_monitor_value`` does it, so
        dicts and nested containers of losses are handled as well as flat
        lists and arrays.

        Raises:
            ValueError: If ``monitor`` is neither 'train loss' nor 'test loss'.
        """
        if self.monitor == "train loss":
            losses = self.model.train_state.loss_train
        elif self.monitor == "test loss":
            losses = self.model.train_state.loss_test
        else:
            raise ValueError("The specified monitor function is incorrect.")
        return np.sum(jax.tree.leaves(losses))
class EarlyStopping(Callback):
    """Stop training when the monitored loss has stopped improving.

    The check runs each time ``on_epoch_end`` is called.

    Args:
        min_delta: Minimum decrease of the monitored quantity that counts as
            an improvement; the value must drop below ``best - min_delta``.
        patience: Number of consecutive non-improving checks after which
            training is stopped.
        baseline: Initial value of ``best``. If given, training stops unless
            the monitored quantity improves on it within ``patience`` checks.
        monitor: The loss that is monitored. Either 'loss_train' or
            'loss_test'.
        start_from_epoch: Checks are skipped while
            ``train_state.epoch < start_from_epoch`` (warm-up period).
    """

    def __init__(
        self,
        min_delta=0,
        patience=0,
        baseline=None,
        monitor="loss_train",
        start_from_epoch=0,
    ):
        super().__init__()
        self.baseline = baseline
        self.monitor = monitor
        self.patience = patience
        self.wait = 0
        self.stopped_epoch = 0
        self.start_from_epoch = start_from_epoch
        self.monitor_op = np.less
        # Stored negated: ``current - self.min_delta`` in ``on_epoch_end`` is
        # therefore ``current + min_delta``, which must be below ``best``.
        self.min_delta = -min_delta
        # Defined here as well so ``on_epoch_end`` does not raise
        # AttributeError when it runs before ``on_train_begin``.
        self.best = baseline if baseline is not None else np.inf

    def on_train_begin(self):
        """Reset the counters and ``best`` so the instance can be re-used."""
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self):
        """Update ``best``/``wait`` and request a stop once patience runs out."""
        if self.model.train_state.epoch < self.start_from_epoch:
            return
        current = self.get_monitor_value()
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.model.train_state.epoch
                self.model.stop_training = True

    def on_train_end(self):
        """Report the epoch at which training was stopped early, if any."""
        if self.stopped_epoch > 0:
            print("Epoch {}: early stopping".format(self.stopped_epoch))

    def get_monitor_value(self):
        """Return the monitored loss summed over all leaves of its container.

        Raises:
            ValueError: If ``monitor`` is neither 'loss_train' nor 'loss_test'.
        """
        if self.monitor == "loss_train":
            result = np.sum(jax.tree.leaves(self.model.train_state.loss_train))
        elif self.monitor == "loss_test":
            result = np.sum(jax.tree.leaves(self.model.train_state.loss_test))
        else:
            raise ValueError("The specified monitor function is incorrect.")
        return result
class Timer(Callback):
    """Request a training stop once a wall-clock budget is exhausted.

    The clock starts at the first ``on_train_begin`` call and is not reset by
    later ones.

    Args:
        available_time (float): Total time (in minutes) available for the
            training.
    """

    def __init__(self, available_time):
        super().__init__()
        # Budget in seconds, to compare directly against ``time.time()`` deltas.
        self.threshold = available_time * 60
        self.t_start = None

    def on_train_begin(self):
        """Record the start time, but only on the very first call."""
        if self.t_start is not None:
            return
        self.t_start = time.time()

    def on_epoch_end(self):
        """Set ``model.stop_training`` and print a notice when time is up."""
        out_of_time = time.time() - self.t_start > self.threshold
        if not out_of_time:
            return
        self.model.stop_training = True
        notice = "\nStop training as time used up. time used: {:.1f} mins, epoch trained: {}".format(
            (time.time() - self.t_start) / 60, self.model.train_state.epoch
        )
        print(notice)
class DropoutUncertainty(Callback):
    """Uncertainty estimation via MC dropout.

    Every ``period`` calls of ``on_epoch_end`` the trainer outputs on
    ``train_state.X_test`` are evaluated ``num_samples`` times, and the
    standard deviation across those evaluations is stored in
    ``train_state.y_std_test``.

    References:
        `Y. Gal, & Z. Ghahramani. Dropout as a Bayesian approximation:
        Representing trainer uncertainty in deep learning. International
        Conference on Machine Learning, 2016
        <https://arxiv.org/abs/1506.02142>`_.

    Warning:
        This cannot be used together with other techniques that have
        different behaviors during training and testing, such as batch
        normalization.

    Args:
        period (int): Interval (number of ``on_epoch_end`` calls) between
            uncertainty estimates.
        num_samples (int): Number of forward passes per estimate. Defaults to
            1000, the value that used to be hard-coded.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """

    def __init__(self, period=1000, num_samples=1000):
        super().__init__()
        if num_samples < 1:
            raise ValueError("`num_samples` must be at least 1.")
        self.period = period
        self.num_samples = num_samples
        self.epochs_since_last = 0

    def on_epoch_end(self):
        """Recompute ``y_std_test`` once ``period`` calls have accumulated."""
        self.epochs_since_last += 1
        if self.epochs_since_last < self.period:
            return
        self.epochs_since_last = 0

        # NOTE(review): the leading ``True`` presumably selects training mode
        # so dropout stays active -- confirm against ``Trainer.fn_outputs``.
        y_preds = [
            self.model.fn_outputs(True, self.model.train_state.X_test)
            for _ in range(self.num_samples)
        ]
        # Stack the samples leaf by leaf along a new leading axis, then reduce
        # that axis with std.
        stacked = jax.tree.map(
            lambda *x: u.math.stack(x, axis=0),
            *y_preds,
            is_leaf=u.math.is_quantity
        )
        self.model.train_state.y_std_test = jax.tree.map(
            lambda x: u.math.std(x, axis=0),
            stacked,
            is_leaf=u.math.is_quantity
        )

    def on_train_end(self):
        """Delegate to ``on_epoch_end``, i.e. count as one more period tick."""
        self.on_epoch_end()
class VariableValue(Callback):
    """Print the current values of a set of state variables during training.

    Args:
        var_list: A ``brainstate.State`` or a list/tuple of
            ``brainstate.State`` objects.
        period (int): Interval (number of ``on_epoch_end`` calls) between
            outputs.
        filename (string): Output the values to the file `filename`. The file
            is kept open to allow instances to be re-used. If ``None``, output
            to the screen.
        precision (int): The precision of variables to display.

    Raises:
        ValueError: If any entry of ``var_list`` is not a
            ``brainstate.State``.
    """

    def __init__(self, var_list, period=1, filename=None, precision=2):
        super().__init__()
        if not isinstance(var_list, (tuple, list)):
            var_list = [var_list]
        self.var_list = var_list
        if not all(isinstance(var, brainstate.State) for var in self.var_list):
            raise ValueError("The variable must be a brainstate.State object.")

        self.period = period
        self.precision = precision

        if filename is None:
            self.file = sys.stdout
        else:
            # Line-buffered; intentionally left open for the instance lifetime.
            self.file = open(filename, "w", buffering=1)
        self.value = None
        self.epochs_since_last = 0

    def on_train_begin(self):
        """Snapshot the variable values and write one ``epoch values`` line."""
        self.value = [state.value for state in self.var_list]
        epoch = self.model.train_state.epoch
        text = utils.list_to_str(self.value, precision=self.precision)
        print(epoch, text, file=self.file)
        self.file.flush()

    def on_epoch_end(self):
        """Write the values once ``period`` calls have accumulated."""
        self.epochs_since_last += 1
        if self.epochs_since_last < self.period:
            return
        self.epochs_since_last = 0
        self.on_train_begin()

    def on_train_end(self):
        """Write a final line unless the last epoch already produced one."""
        if self.epochs_since_last != 0:
            self.on_train_begin()

    def get_value(self):
        """Return the most recently captured variable values."""
        return self.value
class OperatorPredictor(Callback):
    """Evaluate an operator on fixed input samples and print the result.

    Args:
        x: The input data, passed to ``self.model.problem.approximator``.
        op: The operator, called as ``op(x, outputs)``.
        period (int): Interval (number of ``on_epoch_end`` calls) between
            outputs.
        filename (string): Output the values to the file `filename`. The file
            is kept open to allow instances to be re-used. If ``None``, output
            to the screen.
        precision (int): The precision of variables to display.
    """

    def __init__(
        self,
        x: Dict,
        op: Callable,
        period: int = 1,
        filename: str = None,
        precision: int = 2
    ):
        super().__init__()
        self.x = x
        self.op = op
        self.period = period
        self.precision = precision
        if filename is None:
            self.file = sys.stdout
        else:
            # Line-buffered; intentionally left open for the instance lifetime.
            self.file = open(filename, "w", buffering=1)
        self.value = None
        self.epochs_since_last = 0

    def init(self):
        """Nothing to set up when a trainer is attached."""
        pass

    def on_train_begin(self):
        """Refresh the operator value and write one ``epoch value`` line."""
        self.on_predict_end()
        epoch = self.model.train_state.epoch
        text = utils.tree_repr(self.value, precision=self.precision)
        print(epoch, text, file=self.file)
        self.file.flush()

    def on_train_end(self):
        """Write a final line unless the last epoch already produced one."""
        if self.epochs_since_last != 0:
            self.on_train_begin()

    def on_epoch_end(self):
        """Write the value once ``period`` calls have accumulated."""
        self.epochs_since_last += 1
        if self.epochs_since_last < self.period:
            return
        self.epochs_since_last = 0
        self.on_train_begin()

    def on_predict_end(self):
        """Recompute and cache the operator value."""
        self.value = self._eval()

    @brainstate.transform.jit(static_argnums=0)
    def _eval(self):
        """Run the approximator on ``x`` with ``fit=False`` and apply ``op``."""
        # NOTE(review): the collapsed source does not show whether the operator
        # call belongs inside this context block; it is kept inside here --
        # confirm against the upstream file.
        with brainstate.environ.context(fit=False):
            outputs = self.model.problem.approximator(self.x)
            return self.op(self.x, outputs)

    def get_value(self):
        """Return the most recently computed operator value."""
        return self.value
class MovieDumper(Callback):
    """Dump a movie to show the training progress of the function along a line.

    Args:
        filename: Prefix of every output file (``_x.txt``, ``_y.txt``,
            ``_y.gif``, and the ``_spectrum`` variants).
        x1: Start point of the line.
        x2: End point of the line.
        num_points: Number of equally spaced sample points from ``x1`` to
            ``x2`` (both included).
        period: Interval (number of ``on_epoch_end`` calls) between frames.
        component: Index of the output column that is recorded.
        save_spectrum: If True, also dump the magnitude of the real FFT of
            every frame.
        y_reference: Optional callable evaluated on the sample points; its
            flattened result is passed to the animation as the reference.
    """

    def __init__(
        self,
        filename,
        x1,
        x2,
        num_points=100,
        period=1,
        component=0,
        save_spectrum=False,
        y_reference=None,
    ):
        super().__init__()
        self.filename = filename
        start = np.array(x1)
        end = np.array(x2)
        step = (end - start) / (num_points - 1)
        # Column vector of indices so each row of ``self.x`` is one point.
        indices = np.arange(num_points)[:, None]
        self.x = (start + step * indices).astype(dtype=brainstate.environ.dftype())
        self.period = period
        self.component = component
        self.save_spectrum = save_spectrum
        self.y_reference = y_reference
        self.y = []
        self.spectrum = []
        self.epochs_since_last_save = 0

    def on_train_begin(self):
        """Record one frame (and its spectrum, if requested)."""
        outputs = self.model.fn_outputs(False, self.x)
        self.y.append(outputs[:, self.component])
        if self.save_spectrum:
            self.spectrum.append(np.abs(np.fft.rfft(self.y[-1])))

    def on_epoch_end(self):
        """Record a frame once ``period`` calls have accumulated."""
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        self.on_train_begin()

    def on_train_end(self):
        """Write the recorded frames as text files and animations."""
        has_reference = self.y_reference is not None

        x_path = self.filename + "_x.txt"
        y_path = self.filename + "_y.txt"
        movie_path = self.filename + "_y.gif"
        print(
            "\nSaving the movie of function to {}, {}, {}...".format(
                x_path, y_path, movie_path
            )
        )
        np.savetxt(x_path, self.x)
        np.savetxt(y_path, np.array(self.y))

        # The reference keyword is only passed when a reference was supplied.
        reference_kwargs = {}
        if has_reference:
            y_reference = np.ravel(self.y_reference(self.x))
            reference_kwargs["y_reference"] = y_reference
        utils.save_animation(movie_path, np.ravel(self.x), self.y, **reference_kwargs)

        if not self.save_spectrum:
            return

        spectrum_path = self.filename + "_spectrum.txt"
        spectrum_movie_path = self.filename + "_spectrum.gif"
        print(
            "Saving the movie of spectrum to {}, {}...".format(
                spectrum_path, spectrum_movie_path
            )
        )
        np.savetxt(spectrum_path, np.array(self.spectrum))
        frequency_bins = np.arange(len(self.spectrum[0]))
        spectrum_kwargs = {}
        if has_reference:
            spectrum_kwargs["y_reference"] = np.abs(np.fft.rfft(y_reference))
        utils.save_animation(
            spectrum_movie_path, frequency_bins, self.spectrum, logy=True, **spectrum_kwargs
        )
class PDEPointResampler(Callback):
    """Resample the training points for PDE and/or BC losses every given period.

    Args:
        period: Number of ``on_epoch_end`` calls between two resamplings
            (default is 100).
        pde_points: First argument forwarded to
            ``problem.resample_train_points``; if True, the PDE training
            points are resampled (default is True).
        bc_points: Second argument forwarded to
            ``problem.resample_train_points``; if True, the BC training
            points are resampled (default is False).
    """

    def __init__(self, period=100, pde_points=True, bc_points=False):
        super().__init__()
        self.period = period
        self.pde_points = pde_points
        self.bc_points = bc_points

        self.num_bcs_initial = None
        self.epochs_since_last_resample = 0

    def on_train_begin(self):
        """Remember ``problem.num_bcs`` so later changes can be detected."""
        self.num_bcs_initial = self.model.problem.num_bcs

    def on_epoch_end(self):
        """Resample once ``period`` calls have accumulated.

        Raises:
            ValueError: If ``problem.num_bcs`` no longer equals the value
                recorded in ``on_train_begin`` after resampling.
        """
        self.epochs_since_last_resample += 1
        if self.epochs_since_last_resample >= self.period:
            self.epochs_since_last_resample = 0
            self.model.problem.resample_train_points(self.pde_points, self.bc_points)

            unchanged = np.array_equal(self.num_bcs_initial, self.model.problem.num_bcs)
            if unchanged:
                return
            print("Initial value of self.num_bcs:", self.num_bcs_initial)
            print("self.trainer.problem.num_bcs:", self.model.problem.num_bcs)
            raise ValueError(
                "`num_bcs` changed! Please update the loss function by `trainer.compile`."
            )