Source code for pde.tools.modelrunner

"""Establishes hooks for the interplay between :mod:`pde` and :mod:`modelrunner`

This package defines a function for hook registration, which is usually called
automatically during import if :mod:`modelrunner` is available. In this case, grids and
fields of :mod:`pde` can be directly written to storages from :mod:`modelrunner.storage`.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from collections.abc import Sequence

from .misc import module_available


def register_modelrunner_hooks() -> None:
    """Register the `modelrunner` storage hooks for grids and fields.

    The function does nothing if the :mod:`modelrunner` package is not available.
    Otherwise, read and write actions are registered for
    :class:`~pde.grids.base.GridBase` and :class:`~pde.fields.base.FieldBase`.
    """
    modelrunner_present = module_available("modelrunner")
    if not modelrunner_present:
        return

    # imports are deferred to this point since `modelrunner` is optional
    from modelrunner.storage import StorageBase, storage_actions
    from modelrunner.storage.utils import decode_class

    from ..fields.base import FieldBase
    from ..grids.base import GridBase

    def _lookup_stored_class(storage: StorageBase, loc: Sequence[str]):
        """Decode the class recorded in the `__class__` attribute at `loc`."""
        raw_attrs = storage._read_attrs(loc)
        return decode_class(raw_attrs.get("__class__"))

    def load_grid(storage: StorageBase, loc: Sequence[str]) -> GridBase:
        """Read a grid from a modelrunner storage.

        Args:
            storage (:class:`~modelrunner.storage.group.StorageGroup`):
                Storage from which the grid is read
            loc (Location):
                Location of the grid in the storage

        Returns:
            :class:`~pde.grids.base.GridBase`: the grid restored from the storage
        """
        # the concrete grid class is determined before the state is read
        grid_cls = _lookup_stored_class(storage, loc)
        return grid_cls.from_state(storage.read_attrs(loc))  # type: ignore

    def save_grid(storage: StorageBase, loc: Sequence[str], grid: GridBase) -> None:
        """Write a grid to a modelrunner storage.

        Args:
            storage (:class:`~modelrunner.storage.group.StorageGroup`):
                Storage to which the grid is written
            loc (Location):
                Location of the grid in the storage
            grid (:class:`~pde.grids.base.GridBase`):
                The grid that is written
        """
        # the grid is written as attributes (its state) without an object payload
        grid_state = grid.state
        storage.write_object(loc, None, attrs=grid_state, cls=grid.__class__)

    def load_field(storage: StorageBase, loc: Sequence[str]) -> FieldBase:
        """Read a field from a modelrunner storage.

        Args:
            storage (:class:`~modelrunner.storage.group.StorageGroup`):
                Storage from which the field is read
            loc (Location):
                Location of the field in the storage

        Returns:
            :class:`~pde.fields.base.FieldBase`: the field restored from the storage
        """
        field_cls = _lookup_stored_class(storage, loc)
        serialized = storage.read_attrs(loc)
        attributes = field_cls.unserialize_attributes(serialized)  # type: ignore
        data = storage.read_array(loc)
        return field_cls.from_state(attributes, data=data)  # type: ignore

    def save_field(storage: StorageBase, loc: Sequence[str], field: FieldBase) -> None:
        """Write a field to a modelrunner storage.

        Args:
            storage (:class:`~modelrunner.storage.group.StorageGroup`):
                Storage to which the field is written
            loc (Location):
                Location of the field in the storage
            field (:class:`~pde.fields.base.FieldBase`):
                The field that is written
        """
        field_data = field.data
        field_attrs = field.attributes_serialized
        storage.write_array(loc, field_data, attrs=field_attrs, cls=field.__class__)

    # these actions are inherited by all subclasses by default
    hooks = (
        ("read_item", GridBase, load_grid),
        ("write_item", GridBase, save_grid),
        ("read_item", FieldBase, load_field),
        ("write_item", FieldBase, save_field),
    )
    for action, base_cls, handler in hooks:
        storage_actions.register(action, base_cls, handler)