Source code for pinnx.utils.external

# Rewrite of the original file in DeepXDE: https://github.com/lululxvi/deepxde
# ==============================================================================


"""External utilities."""

import csv
import importlib.util
import os
from multiprocessing import Pool

import braintools
import brainunit as u
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sklearn_installed = importlib.util.find_spec("sklearn")

if sklearn_installed:
    from sklearn import preprocessing


[docs] def apply(func, args=None, kwds=None): """Launch a new process to call the function. This can be used to clear Tensorflow GPU memory after trainer execution: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution """ with Pool(1) as p: if args is None and kwds is None: r = p.apply(func) elif kwds is None: r = p.apply(func, args=args) elif args is None: r = p.apply(func, kwds=kwds) else: r = p.apply(func, args=args, kwds=kwds) return r
[docs] def standardize(X_train, X_test): """Standardize features by removing the mean and scaling to unit variance. The mean and std are computed from the training data `X_train` using `sklearn.preprocessing.StandardScaler <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html>`_, and then applied to the testing data `X_test`. Args: X_train: A NumPy array of shape (n_samples, n_features). The data used to compute the mean and standard deviation used for later scaling along the features axis. X_test: A NumPy array. Returns: scaler: Instance of ``sklearn.preprocessing.StandardScaler``. X_train: Transformed training data. X_test: Transformed testing data. """ train_exp_dim = False if u.math.ndim(X_train) == 1: train_exp_dim = True X_train = X_train.reshape(-1, 1) test_exp_dim = False if u.math.ndim(X_test) == 1: test_exp_dim = True X_test = X_test.reshape(-1, 1) if not sklearn_installed: raise ImportError("scikit-learn is not installed. Please install it to use the standardize function.") scaler = preprocessing.StandardScaler(with_mean=True, with_std=True) X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) if train_exp_dim: X_train = X_train.flatten() if test_exp_dim: X_test = X_test.flatten() return X_train, X_test
[docs] def saveplot( loss_history, train_state, issave=True, isplot=True, loss_fname="loss.dat", train_fname="train.dat", test_fname="test.dat", output_dir=None, ): """Save/plot the loss history and best trained result. This function is used to quickly check your results. To better investigate your result, use ``save_loss_history()`` and ``save_best_state()``. Args: loss_history: ``LossHistory`` instance. The first variable returned from ``Trainer.train()``. train_state: ``TrainState`` instance. The second variable returned from ``Trainer.train()``. issave (bool): Set ``True`` (default) to save the loss, training points, and testing points. isplot (bool): Set ``True`` (default) to plot loss, metric, and the predicted solution. loss_fname (string): Name of the file to save the loss in. train_fname (string): Name of the file to save the training points in. test_fname (string): Name of the file to save the testing points in. output_dir (string): If ``None``, use the current working directory. """ if output_dir is None: output_dir = os.getcwd() if not os.path.exists(output_dir): print(f"Warning: Directory {output_dir} doesn't exist. Creating it.") os.mkdir(output_dir) if issave: loss_fname = os.path.join(output_dir, loss_fname) train_fname = os.path.join(output_dir, train_fname) test_fname = os.path.join(output_dir, test_fname) save_loss_history(loss_history, loss_fname) save_best_state(train_state, train_fname, test_fname) if isplot: plot_loss_history(loss_history) plot_best_state(train_state) plt.show()
[docs] def plot_loss_history(loss_history, fname=None): """Plot the training and testing loss history. Note: You need to call ``plt.show()`` to show the figure. Args: loss_history: ``LossHistory`` instance. The first variable returned from ``Trainer.train()``. fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the figure to the file of the file name `fname`. """ # np.sum(loss_history.loss_train, axis=1) is error-prone for arrays of varying lengths. # Handle irregular array sizes. loss_train = jnp.array([jnp.sum(jnp.asarray(jax.tree.leaves(loss))) for loss in loss_history.loss_train]) loss_test = jnp.array([jnp.sum(jnp.asarray(jax.tree.leaves(loss))) for loss in loss_history.loss_test]) plt.figure() plt.semilogy(loss_history.steps, loss_train, label="Train loss") plt.semilogy(loss_history.steps, loss_test, label="Test loss") metric_tests = jax.tree.map(lambda *a: u.math.asarray(a), *loss_history.metrics_test) for i in range(len(loss_history.metrics_test[0])): if isinstance(metric_tests[i], dict): for k, v in metric_tests[i].items(): plt.semilogy(loss_history.steps, v, label=f"Test metric {k}") else: plt.semilogy(loss_history.steps, metric_tests[i], label=f"Test metric {i}") plt.xlabel("# Steps") plt.legend() if isinstance(fname, str): plt.savefig(fname)
[docs] def save_loss_history(loss_history, fname): """Save the training and testing loss history to a file.""" print("Saving loss history to {} ...".format(fname)) train_losses = jax.tree.map(lambda *a: u.math.asarray(a), *loss_history.loss_train) braintools.file.msgpack_save(fname, train_losses)
def _pack_data(train_state): def merge_values(values): if values is None: return None return jnp.hstack(values) if isinstance(values, (list, tuple)) else values # y_train = merge_values(train_state.y_train) # y_test = merge_values(train_state.y_test) # best_y = merge_values(train_state.best_y) # best_ystd = merge_values(train_state.best_ystd) y_train = train_state.y_train y_test = train_state.y_test best_y = train_state.best_y best_ystd = train_state.best_ystd return y_train, y_test, best_y, best_ystd
[docs] def plot_best_state(train_state): """Plot the best result of the smallest training loss. This function only works for 1D and 2D problems. For other problems and to better customize the figure, use ``save_best_state()``. Note: You need to call ``plt.show()`` to show the figure. Args: train_state: ``TrainState`` instance. The second variable returned from ``Trainer.train()``. """ if isinstance(train_state.X_train, (list, tuple)): print("Error: The network has multiple inputs, and plotting such result hasn't been implemented.") return y_train, y_test, best_y, best_ystd = _pack_data(train_state) xkeys = tuple(train_state.X_test.keys()) # Regression plot # 1D if len(train_state.X_test) == 1: idx = u.math.argsort(train_state.X_test[xkeys[0]]) X = train_state.X_test[xkeys[0]][idx] plt.figure() for ykey in best_y: if y_train is not None: plt.plot(train_state.X_train[xkeys[0]], y_train[ykey], "ok", label="Train") if y_test is not None: plt.plot(X, y_test[ykey], "-k", label="True") y_val, y_unit = u.split_mantissa_unit(best_y[ykey]) plt.plot( X, y_val, "--r", label=(f"{ykey} Prediction" if y_unit.is_unitless else f"{ykey} Prediction [{y_unit}]") ) if best_ystd is not None: ystd_val = u.get_magnitude(best_ystd[ykey].to(y_unit)) plt.plot(X, y_val + 1.96 * ystd_val, "-b", label="95% CI") plt.plot(X, y_val - 1.96 * ystd_val, "-b") plt.xlabel("x") plt.ylabel("y") plt.legend() # 2D elif len(train_state.X_test) == 2: for ykey in best_y: plt.figure() ax = plt.axes(projection=Axes3D.name) ax.plot3D( u.get_magnitude(train_state.X_test[xkeys[0]]), u.get_magnitude(train_state.X_test[xkeys[1]]), u.get_magnitude(best_y[ykey]), ".", ) unit = u.get_unit(train_state.X_test[xkeys[0]]) if unit.is_unitless: ax.set_xlabel(f'{xkeys[0]}') else: ax.set_xlabel(f'{xkeys[0]} [{unit}]') unit = u.get_unit(train_state.X_test[xkeys[1]]) if unit.is_unitless: ax.set_ylabel(f'{xkeys[1]}') else: ax.set_ylabel(f'{xkeys[1]} [{unit}]') unit = u.get_unit(best_y[ykey]) if unit.is_unitless: ax.set_zlabel(f'{ykey}') else: ax.set_zlabel(f'{ykey} [{unit}]')
# Residual plot # Not necessary to plot # if y_test is not None: # plt.figure() # residual = y_test[:, 0] - best_y[:, 0] # plt.plot(best_y[:, 0], residual, "o", zorder=1) # plt.hlines(0, plt.xlim()[0], plt.xlim()[1], linestyles="dashed", zorder=2) # plt.xlabel("Predicted") # plt.ylabel("Residual = Observed - Predicted") # plt.tight_layout() # Uncertainty plot # Not necessary to plot # if best_ystd is not None: # plt.figure() # for i in range(y_dim): # plt.plot(train_state.X_test[:, 0], best_ystd[:, i], "-b") # plt.plot( # train_state.X_train[:, 0], # np.interp( # train_state.X_train[:, 0], train_state.X_test[:, 0], best_ystd[:, i] # ), # "ok", # ) # plt.xlabel("x") # plt.ylabel("std(y)")
[docs] def save_best_state(train_state, fname_train, fname_test): """Save the best result of the smallest training loss to a file.""" if isinstance(train_state.X_train, (list, tuple)): print("Error: The network has multiple inputs, and saving such result han't been implemented.") return print("Saving training data to {} ...".format(fname_train)) y_train, y_test, best_y, best_ystd = _pack_data(train_state) if y_train is None: data = {'X_train': train_state.X_train} else: data = {'X_train': train_state.X_train, 'y_train': y_train} braintools.file.msgpack_save(fname_train, data) print("Saving test data to {} ...".format(fname_test)) if y_test is None: data = {'X_test': train_state.X_test, 'best_y': best_y} if best_ystd is not None: data['best_ystd'] = best_ystd braintools.file.msgpack_save(fname_test, data) else: data = {'X_test': train_state.X_test, 'best_y': best_y, 'y_test': y_test} if best_ystd is not None: data['best_ystd'] = best_ystd braintools.file.msgpack_save(fname_test, data)
[docs] def dat_to_csv(dat_file_path, csv_file_path, columns): """Converts a dat file to CSV format and saves it. Args: dat_file_path (string): Path of the dat file. csv_file_path (string): Desired path of the CSV file. columns (list): Column names to be added in the CSV file. """ with open(dat_file_path, "r", encoding="utf-8") as dat_file, open( csv_file_path, "w", encoding="utf-8", newline="" ) as csv_file: csv_writer = csv.writer(csv_file) csv_writer.writerow(columns) for line in dat_file: if "#" in line: continue row = [field.strip() for field in line.split(" ")] csv_writer.writerow(row)
[docs] def isclose(a, b): """A modified version of `np.isclose` for DeepXDE. This function changes the value of `atol` due to the dtype of `a` and `b`. If the dtype is float16, `atol` is `1e-4`. If it is float32, `atol` is `1e-6`. Otherwise (for float64), the default is `1e-8`. If you want to manually set `atol` for some reason, use `np.isclose` instead. Args: a, b (array like): DictToArray arrays to compare. """ pack = smart_numpy(a) a_dtype = a.dtype a_unit = u.get_unit(a) if a_dtype == jnp.float32: atol = u.maybe_decimal(u.Quantity(1e-6, unit=a_unit)) elif a_dtype == jnp.float16: atol = u.maybe_decimal(u.Quantity(1e-4, unit=a_unit)) else: atol = u.maybe_decimal(u.Quantity(1e-8, unit=a_unit)) return pack.isclose(a, b, atol=atol)
[docs] def smart_numpy(x): if isinstance(x, jnp.ndarray): return jnp elif isinstance(x, jax.Array): return jax.numpy elif isinstance(x, u.Quantity): return u.math else: raise TypeError(f"Unknown type {type(x)}.")