# Source code for pinnx.nn.convert

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict

import brainstate
import brainunit as u

__all__ = [
    'DictToArray',
    'ArrayToDict',
]


def dict_to_array(
    d: Dict[str, brainstate.typing.ArrayLike],
    axis: int = 1
):
    """
    Stack the values of a dictionary into a single array.

    The values are taken in the dictionary's iteration (insertion) order and
    joined along a *new* axis with ``u.math.stack``. Stacking, unlike
    concatenation, expects all values to have the same shape.

    Args:
        d (dict): Mapping from name to array-like value.
        axis (int): Position of the newly inserted axis in the result.

    Returns:
        ndarray: The stacked array.
    """
    columns = list(d.values())
    return u.math.stack(columns, axis=axis)


class DictToArray(brainstate.nn.Module):
    """
    DictToArray layer: strip units from a dict of inputs and stack them into one array.

    For every key configured in ``units`` the matching input value is made unitless:

    * if the configured unit is a ``brainunit.Unit``, a ``Quantity`` input is
      expressed in that unit via ``to_decimal``. A non-``Quantity`` input is
      accepted unchanged, but only when that unit is dimensionless;
    * if the configured unit is None, the input is passed through
      ``u.maybe_decimal``.

    The converted values are then stacked along ``axis`` in the order the
    units were given to the constructor.

    Args:
        axis (int): The axis along which the converted values are stacked.
        **units: The unit for each input key. Each unit should be an instance
            of ``brainunit.Unit``, or None.
    """

    def __init__(self, axis: int = -1, **units):
        super().__init__()

        assert isinstance(axis, int), "DictToArray axis must be an integer. Please check the input values."
        self.axis = axis

        # The keyword order fixes the order of entries in the stacked output.
        self.units = units
        assert all(isinstance(unit, u.Unit) or unit is None for unit in units.values()), (
            "DictToArray values must be a unit or None. Please check the input values."
        )

        n_entries = len(units)
        self.in_size = n_entries
        self.out_size = n_entries

    def update(self, x: Dict[str, brainstate.typing.ArrayLike]):
        expected_keys = set(self.units.keys())
        given_keys = set(x.keys())
        assert given_keys == expected_keys, (f"DictToArray keys mismatch. "
                                             f"{given_keys} != {expected_keys}.")

        decimals = {key: self._strip_unit(unit, x[key]) for key, unit in self.units.items()}
        return dict_to_array(decimals, axis=self.axis)

    @staticmethod
    def _strip_unit(unit, val):
        """Convert a single input value to a unitless value according to its configured ``unit``."""
        if not isinstance(unit, u.Unit):
            # No unit configured: delegate to ``u.maybe_decimal``.
            return u.maybe_decimal(val)
        # A raw (non-Quantity) value is only acceptable for a dimensionless unit.
        assert isinstance(val, u.Quantity) or unit.dim == u.DIMENSIONLESS, (
            "DictToArray values must be a quantity. Please check the input values."
        )
        if isinstance(val, u.Quantity):
            return val.to_decimal(unit)
        return val
class ArrayToDict(brainstate.nn.Module):
    """
    Output layer: split an array along ``axis`` into a dict and attach units.

    The array is split into ``len(units)`` equal pieces along ``axis``. Each
    piece has that axis squeezed away, and is then multiplied by its
    configured unit when one is given. Keys are assigned in the order the
    units were passed to the constructor.

    Args:
        axis (int): The axis along which the output data is split.
        **units: The unit for each output key. Each unit should be an instance
            of ``brainunit.Unit``, or None (the piece is returned without a unit).
    """

    def __init__(self, axis: int = -1, **units):
        super().__init__()
        assert isinstance(axis, int), "ArrayToDict axis must be an integer. Please check the input values."
        self.axis = axis

        # The keyword order fixes which slice along ``axis`` maps to which key.
        self.units = units
        for val in units.values():
            assert isinstance(val, u.Unit) or val is None, ("ArrayToDict values must be a unit or None. "
                                                            "Please check the input values.")
        self.in_size = len(units)
        self.out_size = len(units)

    def update(self, arr: brainstate.typing.ArrayLike) -> Dict[str, brainstate.typing.ArrayLike]:
        """
        Split ``arr`` into named entries.

        Args:
            arr: Array whose size along ``self.axis`` must equal the number of units.

        Returns:
            dict: One entry per configured key. Each entry is the squeezed
            slice, multiplied by its unit when the unit is not None.
        """
        assert arr.shape[self.axis] == len(self.units), (f"The number of columns of x must be "
                                                         f"equal to the number of units. "
                                                         f"Got {arr.shape[self.axis]} != {len(self.units)}. "
                                                         "Please check the input values.")
        # Each piece keeps a size-1 axis at ``self.axis``; it is squeezed below.
        pieces = u.math.split(arr, len(self.units), axis=self.axis)

        res = dict()
        for (key, unit), piece in zip(self.units.items(), pieces):
            value = u.math.squeeze(piece, axis=self.axis)
            if unit is not None:
                value *= unit
            res[key] = value
        return res