"""Base classes for storing data.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
from __future__ import annotations
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator, Sequence
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Literal
import numpy as np
from numpy.typing import DTypeLike
from ..fields import FieldCollection, ScalarField, Tensor2Field, VectorField
from ..fields.base import FieldBase
from ..fields.datafield_base import DataFieldBase
from ..grids.base import GridBase
from ..tools.docstrings import fill_in_docstring
from ..tools.output import display_progress
from ..trackers.base import InfoDict, TrackerBase
from ..trackers.interrupts import InterruptData
from .memory import MemoryStorage
_base_logger = logging.getLogger(__name__.rsplit(".", 1)[0])
""":class:`logging.Logger`: Base logger for storage."""

# Admissible values for the `write_mode` argument of storage classes. The values
# mirror the modes described in the docstring of `StorageBase.__init__`.
WriteModeType = Literal[
    "append",
    "readonly",
    "truncate",
    "truncate_once",
]
class StorageBase(metaclass=ABCMeta):
    """Base class for storing time series of discretized fields.

    These classes store time series of :class:`~pde.fields.base.FieldBase`, i.e., they
    store the values of the fields at particular time points. Iterating over the
    storage will return the fields in order and individual time points can also be
    accessed.

    Subclasses need to provide the attributes `times` and `data` and implement the
    method :meth:`_append_data`.
    """

    times: Sequence[float]  # stored time points
    data: Any  # actual data for all the stored times
    write_mode: WriteModeType  # mode determining how the storage behaves
    _logger: logging.Logger  # class-specific logger, set in `__init_subclass__`

    def __init__(
        self,
        info: InfoDict | None = None,
        write_mode: WriteModeType = "truncate_once",
    ):
        """
        Args:
            info (dict):
                Supplies extra information that is stored in the storage
            write_mode (str):
                Determines how new data is added to already existing one. Possible
                values are: 'append' (data is always appended), 'truncate' (data is
                cleared every time this storage is used for writing), or 'truncate_once'
                (data is cleared for the first writing, but subsequent data using the
                same instances are appended). Alternatively, specifying 'readonly' will
                disable writing completely.
        """
        self.info = {} if info is None else info
        self.write_mode = write_mode
        self._data_shape: tuple[int, ...] | None = None
        self._dtype: DTypeLike | None = None
        self._grid: GridBase | None = None
        self._field: FieldBase | None = None

    def __init_subclass__(cls, **kwargs) -> None:
        """Initialize class-level attributes of subclasses."""
        # forward keyword arguments so cooperative class hierarchies keep working
        super().__init_subclass__(**kwargs)
        # create logger for this specific storage class
        cls._logger = _base_logger.getChild(cls.__qualname__)

    @property
    def data_shape(self) -> tuple[int, ...]:
        """The current data shape.

        Raises:
            RuntimeError: if data_shape was not set
        """
        if self._data_shape is None:
            raise RuntimeError("data_shape was not set")
        else:  # use the else clause to help typing
            return self._data_shape

    @property
    def dtype(self) -> DTypeLike:
        """The current data type.

        Raises:
            RuntimeError: if data_type was not set
        """
        if self._dtype is None:
            raise RuntimeError("dtype was not set")
        else:  # use the else clause to help typing
            return self._dtype

    @abstractmethod
    def _append_data(self, data: np.ndarray, time: float) -> None:
        """Append a new data set.

        Args:
            data (:class:`~numpy.ndarray`): The actual data
            time (float, optional): The time point associated with the data
        """

    def append(self, field: FieldBase, time: float | None = None) -> None:
        """Add field to the storage.

        Args:
            field (:class:`~pde.fields.base.FieldBase`):
                The field that is added to the storage
            time (float, optional):
                The time point. If omitted, the time is 0 for an empty storage and
                the last stored time plus one otherwise.

        Raises:
            ValueError: if the grid of `field` differs from the stored grid
        """
        if time is None:
            time = 0 if len(self) == 0 else self.times[-1] + 1

        if self._grid is None:
            # the first field determines the grid of the storage
            self._grid = field.grid
        elif self._grid != field.grid:
            raise ValueError(f"Grids incompatible ({self._grid} != {field.grid})")
        return self._append_data(field.data, time)

    def clear(self, clear_data_shape: bool = False) -> None:
        """Truncate the storage by removing all stored data.

        This base implementation only resets the stored data shape and data type
        when requested; it does not touch `times` or `data` itself.

        Args:
            clear_data_shape (bool):
                Flag determining whether the data shape is also deleted.
        """
        if clear_data_shape:
            self._data_shape = None
            self._dtype = None

    def __len__(self) -> int:
        """Return the number of stored items, i.e., time steps."""
        return len(self.times)

    @property
    def shape(self) -> tuple[int, ...] | None:
        """The shape of the stored data.

        The first entry is the number of stored time steps, followed by the data
        shape of an individual item. `None` is returned if the data shape is not
        known yet.
        """
        # compare to `None` explicitly since an empty tuple is a valid data shape
        if self._data_shape is None:
            return None
        return (len(self),) + self._data_shape

    @property
    def has_collection(self) -> bool:
        """bool: whether the storage is storing a collection

        Raises:
            RuntimeError: if no field is known and the storage is empty
        """
        if self._field is not None:
            return isinstance(self._field, FieldCollection)
        elif len(self) > 0:
            # fall back to inspecting the first stored field
            return isinstance(self._get_field(0), FieldCollection)
        else:
            raise RuntimeError("Storage is empty")

    @property
    def grid(self) -> GridBase | None:
        """GridBase: the grid associated with this storage

        This returns `None` if grid was not stored in `self.info`.
        """
        if self._grid is None:
            if "field_attributes" in self.info:
                attrs_serialized = self.info["field_attributes"]
                attrs = FieldBase.unserialize_attributes(attrs_serialized)

                # load the grid if possible
                if "grid" in attrs:
                    # load grid from only stored field
                    self._grid = attrs["grid"]
                elif "fields" in attrs:
                    # load grid from first field of a collection
                    self._grid = attrs["fields"][0]["grid"]
                else:
                    self._logger.warning(
                        "`grid` attribute was not stored. Available attributes: %s",
                        ", ".join(sorted(attrs.keys())),
                    )

            else:
                self._logger.warning("Field attributes are unavailable in info")
        return self._grid

    def _init_field(self) -> None:
        """Initialize internal field variable.

        The field is restored from the serialized attributes in `self.info` if
        they are available. Otherwise, the field class is guessed from the shape
        of the stored data.

        Raises:
            RuntimeError: if the grid is unknown or the field type cannot be guessed
        """
        grid = self.grid
        if grid is None:
            raise RuntimeError(
                "Could not load grid from data. Please set the `_grid` attribute "
                "to the grid that has been used for the stored data"
            )

        if "field_attributes" in self.info:
            # field type was stored in data
            attrs_serialized = self.info["field_attributes"]
            attrs = FieldBase.unserialize_attributes(attrs_serialized)
            self._field = FieldBase.from_state(attrs)

        else:
            # try to determine field type automatically

            # obtain data shape by removing the first axis (associated with the time
            # series and the last axes (associated with the spatial dimensions). What is
            # left should be the (local) data stored at each grid point for each time
            # step. Note that self.data might be a list of arrays
            first_item = self.data[0]
            local_shape = first_item.shape[: -grid.num_axes]
            dim = grid.dim
            if len(local_shape) == 0:  # rank 0
                self._field = ScalarField(grid, dtype=first_item.dtype)
            elif local_shape == (dim,):  # rank 1
                self._field = VectorField(grid, dtype=first_item.dtype)
            elif local_shape == (dim, dim):  # rank 2
                self._field = Tensor2Field(grid, dtype=first_item.dtype)
            else:
                raise RuntimeError(
                    "`field` attribute was not stored in file and the data shape "
                    f"{local_shape} could not be interpreted automatically"
                )
            self._logger.warning(
                "`field` attribute was not stored. Assume data is of type %s.",
                self._field.__class__.__name__,
            )

    def _get_field(self, t_index: int) -> FieldBase:
        """Return the field corresponding to the given time index.

        Load the data given an index, i.e., the data at time `self.times[t_index]`.

        Args:
            t_index (int):
                The index of the data to load. Negative values count from the end.

        Returns:
            :class:`~pde.fields.FieldBase`:
            The field class containing the grid and data

        Raises:
            IndexError: if `t_index` is out of range
        """
        if t_index < 0:
            t_index += len(self)

        if not 0 <= t_index < len(self):
            raise IndexError("Time index out of range")

        if self._field is None:
            # lazily create the template field on first access
            self._init_field()

        # create the field with the data of the given index
        assert self._field is not None
        field = self._field.copy()
        field.data = self.data[t_index]
        return field

    def __getitem__(self, key: int | slice) -> FieldBase | list[FieldBase]:
        """Return field at given index or a list of fields for a slice.

        Raises:
            TypeError: if `key` is neither an integer nor a slice
        """
        if isinstance(key, int):
            return self._get_field(key)
        elif isinstance(key, slice):
            return [self._get_field(i) for i in range(*key.indices(len(self)))]
        else:
            raise TypeError("Unknown key type")

    def __iter__(self) -> Iterator[FieldBase]:
        """Iterate over all stored fields."""
        for i in range(len(self)):
            yield self[i]  # type: ignore

    def items(self) -> Iterator[tuple[float, FieldBase]]:
        """Iterate over all times and stored fields, returning pairs."""
        for i in range(len(self)):
            yield self.times[i], self[i]  # type: ignore

    def tracker(
        self,
        interrupts: InterruptData = 1,
        transformation: Callable[[FieldBase, float], FieldBase] | None = None,
    ) -> StorageTracker:
        """Create object that can be used as a tracker to fill this storage.

        Args:
            interrupts:
                Determines when the tracker is invoked. The value is forwarded
                unchanged to :class:`StorageTracker`.
            transformation (callable, optional):
                A function that transforms the current state into a new field or field
                collection, which is then stored. This allows to store derived
                quantities of the field during calculations. The argument needs to be a
                callable function taking 1 or 2 arguments. The first argument always is
                the current field, while the optional second argument is the associated
                time.

        Returns:
            :class:`StorageTracker`: The tracker that fills the current storage

        Example:
            The `transformation` argument allows storing additional fields:

            .. code-block:: python

                def add_to_state(state):
                    transformed_field = state.smooth(1)
                    return state.append(transformed_field)

                storage = pde.MemoryStorage()
                tracker = storage.tracker(1, transformation=add_to_state)
                eq.solve(..., tracker=tracker)

            In this example, :obj:`storage` will contain a trajectory of the fields of
            the simulation as well as the smoothed fields. Other transformations are
            possible by defining appropriate :func:`add_to_state`
        """
        return StorageTracker(
            storage=self, interrupts=interrupts, transformation=transformation
        )

    def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
        """Initialize the storage for writing data.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be written to extract the grid
                and the data_shape
            info (dict):
                Supplies extra information that is stored in the storage. This base
                implementation does not use the argument.

        Raises:
            RuntimeError: if the storage is in readonly mode
            ValueError: if the data shape of `field` differs from the stored one
        """
        if self.write_mode == "readonly":
            raise RuntimeError("Cannot write data in readonly mode")

        if self._data_shape is None:
            self._data_shape = field.data.shape
        elif self.data_shape != field.data.shape:
            raise ValueError("Data shape incompatible with stored data")
        if self._dtype is None:
            self._dtype = field.dtype

        # remember the field as a template for reconstructing stored fields
        self._grid = field.grid
        self._field = field.copy()
        self.info["field_attributes"] = field.attributes_serialized

    def end_writing(self) -> None:
        """Finalize the storage after writing."""

    def view_field(self, field_id: int | str) -> StorageView:
        """Returns a view into this storage focusing on a particular field.

        Note:
            Modifying data returned by the view will modify the underlying storage

        Args:
            field_id (int or str):
                The index into the field collection. This determines which field of the
                collection is returned. Instead of a numerical index, the field label
                can also be supplied. If there are multiple fields with the same label,
                only the first field is returned.

        Returns:
            :class:`StorageView`: A view into the storage only returning a single field
        """
        return StorageView(storage=self, field=field_id)

    def apply(
        self,
        func: Callable,
        out: StorageBase | None = None,
        progress: bool = False,
    ) -> StorageBase:
        """Applies function to each field in a storage.

        Args:
            func (callable):
                The function to apply to each stored field. The function must either
                take as a single argument the field or as two arguments the field and
                the associated time point. In both cases, it should return a field.
                A function without any parameters is called without arguments.
            out (:class:`~pde.storage.base.StorageBase`):
                Storage to which the output is written. If omitted, a new
                :class:`~pde.storage.memory.MemoryStorage` is used and returned
            progress (bool):
                Flag indicating whether the progress is shown during the calculation

        Returns:
            :class:`~pde.storage.base.StorageBase`: The new storage that contains the
            data after the function `func` has been applied

        Raises:
            TypeError: if `func` does not return a field
        """
        # get the number of arguments that the user function expects
        num_args = len(signature(func).parameters)
        writing = False  # flag indicating whether output storage was opened

        for t, field in display_progress(
            self.items(), total=len(self), enabled=progress
        ):
            # apply the user function
            if num_args == 0:
                transformed = func()
            elif num_args == 1:
                transformed = func(field)
            else:
                transformed = func(field, t)

            if not isinstance(transformed, FieldBase):
                raise TypeError("The user function must return a field")

            if out is None:
                # import locally to avoid a circular import at module level
                from .memory import MemoryStorage

                out = MemoryStorage(field_obj=transformed)

            if not writing:
                # open the output storage once, using the first result as template
                out.start_writing(transformed)
                writing = True

            out.append(transformed, t)

        if writing:
            out.end_writing()  # type: ignore

        # make sure that a storage is returned, even when no fields are present
        if out is None:
            from .memory import MemoryStorage

            out = MemoryStorage()

        return out

    def copy(
        self, out: StorageBase | None = None, *, progress: bool = False
    ) -> StorageBase:
        """Copies all fields in a storage to a new one.

        Args:
            out (:class:`~pde.storage.base.StorageBase`):
                Storage to which the output is written. If omitted, a new
                :class:`~pde.storage.memory.MemoryStorage` is used and returned
            progress (bool):
                Flag indicating whether the progress is shown during the calculation

        Returns:
            :class:`~pde.storage.base.StorageBase`: The new storage that contains the
            copied data
        """
        # apply the identity function to do the copy
        return self.apply(lambda x: x, out=out, progress=progress)
class StorageTracker(TrackerBase):
    """Tracker that stores data in special storage classes.

    Attributes:
        storage (:class:`~pde.storage.base.StorageBase`):
            The underlying storage class through which the data can be accessed
    """

    def __init__(
        self,
        storage,
        interrupts: InterruptData = 1,
        transformation: Callable[[FieldBase, float], FieldBase] | None = None,
        interval=None,
    ):
        """
        Args:
            storage (:class:`~pde.storage.base.StorageBase`):
                Storage instance to which the data is written
            interrupts:
                Determines when the tracker is invoked. The value is forwarded
                unchanged to the base class.
            transformation (callable, optional):
                A function that transforms the current state into a new field or field
                collection, which is then stored. This allows to store derived
                quantities of the field during calculations. The argument needs to be a
                callable function taking 1 or 2 arguments. The first argument always is
                the current field, while the optional second argument is the associated
                time.
            interval:
                Forwarded unchanged to the base class.

        Raises:
            TypeError: if `transformation` is given but not callable
        """
        super().__init__(interrupts=interrupts, interval=interval)
        self.storage = storage
        if transformation is not None and not callable(transformation):
            raise TypeError("`transformation` must be callable")
        self.transformation = transformation

    @staticmethod
    def _expects_time(func: Callable) -> bool:
        """Determine whether `func` is called with the time as second argument.

        The time is omitted only if `func` declares exactly one positional
        parameter. Using :func:`inspect.signature` instead of inspecting
        `__code__` also supports bound methods, partial functions, and callable
        objects.
        """
        try:
            params = signature(func).parameters.values()
        except (TypeError, ValueError):
            # signature cannot be determined (e.g., for some built-in callables)
            return True
        num_positional = sum(
            p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
        )
        return num_positional != 1

    def _transform(self, field: FieldBase, t: float) -> FieldBase:
        """Transforms the field according to the defined transformation.

        Args:
            field (:class:`~pde.fields.FieldBase`): The field to transform
            t (float): The associated time

        Returns:
            :class:`~pde.fields.FieldBase`: The transformed field, or `field`
            itself if no transformation was defined
        """
        if self.transformation is None:
            return field
        elif self._expects_time(self.transformation):
            return self.transformation(field, t)
        else:
            return self.transformation(field)  # type: ignore

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Initialize the tracker and open the storage for writing.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time the tracker needs to handle data
        """
        t_first = super().initialize(field, info)
        # the storage needs an example of the (transformed) data it will receive
        self.storage.start_writing(self._transform(field, t_first), info)
        return t_first

    def handle(self, field: FieldBase, t: float) -> None:
        """Handle data supplied to this tracker.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float): The associated time
        """
        self.storage.append(self._transform(field, t), time=t)

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker, supplying additional information.

        Args:
            info (dict):
                Extra information from the simulation
        """
        super().finalize(info)
        # close the storage that was opened in `initialize`
        self.storage.end_writing()
class StorageView:
    """Represents a view into a storage that extracts a particular field."""

    has_collection: bool = False  # a view only ever yields a single field

    def __init__(self, storage: StorageBase, *, field: int | str):
        """
        Args:
            storage (:class:`~pde.storage.base.StorageBase`):
                The storage providing the basic data
            field (int or str):
                The index into the field collection determining which field of the
                collection is returned. Instead of a numerical index, the field label
                can also be supplied. If there are multiple fields with the same label,
                only the first field is returned.

        Raises:
            RuntimeError: if the storage does not contain field collections
        """
        self.storage = storage
        if not self.storage.has_collection:
            raise RuntimeError("Can only create view into Storage of field collection")

        if isinstance(field, str):
            # translate the label into the index of the first matching field
            self.field_index = self.storage._field.labels.index(field)  # type: ignore
        else:
            self.field_index = field

    @property
    def times(self) -> Sequence[float]:
        """Sequence[float]: the time points of the underlying storage"""
        return self.storage.times

    @property
    def grid(self) -> GridBase | None:
        """GridBase: the grid of the underlying storage"""
        return self.storage.grid

    def __len__(self) -> int:
        """Return the number of stored items, i.e., time steps."""
        return len(self.storage)

    def __getitem__(self, key: int | slice) -> DataFieldBase | list[DataFieldBase]:
        """Return field at given index or a list of fields for a slice."""
        if isinstance(key, slice):
            # the storage returns a list of collections; pick the field from each
            return [
                fields[self.field_index]  # type: ignore
                for fields in self.storage[key]  # type: ignore
            ]
        return self.storage[key][self.field_index]  # type: ignore

    def __iter__(self) -> Iterator[DataFieldBase]:
        """Iterate over all stored fields."""
        for fields in self.storage:
            yield fields[self.field_index]  # type: ignore

    def items(self) -> Iterator[tuple[float, DataFieldBase]]:
        """Iterate over all times and stored fields, returning pairs."""
        for k, v in self.storage.items():
            yield k, v[self.field_index]  # type: ignore