"""Base classes for trackers.
.. autosummary::
:nosignatures:
TrackerBase
TrackerCollection
TransformedTrackerBase
FinishedSimulation
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import logging
import math
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, Union
from ..fields.base import FieldBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import module_available
from .interrupts import InterruptData, parse_interrupt
# Logger of the parent package (this module's name without its last component);
# every tracker class gets a child of it (see `TrackerBase.__init_subclass__`).
_base_logger = logging.getLogger(__name__.rsplit(".", 1)[0])
""":class:`logging.Logger`: Base logger for trackers."""
# Optional dictionary carrying extra information about a simulation. Note that the
# alias already includes `None`, so annotations like `InfoDict | None` are redundant.
InfoDict = dict[str, Any] | None
# Description of a single tracker: either a tracker instance or the registered name
# of a tracker class (resolved by `TrackerBase.from_data`).
TrackerDataType = Union["TrackerBase", str]
class FinishedSimulation(StopIteration):
    """Exception signaling that a simulation finished successfully.

    Since it derives from :class:`StopIteration`, a tracker may raise it from its
    ``handle`` method to request that the iteration stops;
    :meth:`TrackerCollection.handle` catches such exceptions and re-raises them
    once all due trackers have been handled.
    """
class TrackerBase(metaclass=ABCMeta):
    """Base class for implementing trackers.

    Subclasses that define their own class attribute ``name`` are registered
    automatically and can then be created from that name via :meth:`from_data`.
    """

    _logger: logging.Logger  # logger of the concrete class, set in __init_subclass__
    _subclasses: dict[str, type[TrackerBase]] = {}  # all inheriting classes by name

    @fill_in_docstring
    def __init__(self, interrupts: InterruptData = 1):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
        """
        self.interrupt = parse_interrupt(interrupts)

    def __init_subclass__(cls, **kwargs):
        """Initialize class-level attributes of subclasses.

        Raises:
            ValueError: If the subclass defines the reserved name ``"auto"``.
        """
        super().__init_subclass__(**kwargs)
        # create logger for this specific tracker class
        cls._logger = _base_logger.getChild(cls.__qualname__)
        # Register subclasses by name so they can be reconstructed later. Only classes
        # that define `name` themselves are registered; with `hasattr`, a subclass that
        # merely inherits `name` would silently replace its parent in the registry.
        if "name" in vars(cls):
            if cls.name == "auto":
                # "auto" has a special meaning in `TrackerCollection.from_data`
                msg = "Trackers must not use the reserved name `auto`"
                raise ValueError(msg)
            cls._subclasses[cls.name] = cls

    @classmethod
    def from_data(cls, data: TrackerDataType, **kwargs) -> TrackerBase:
        """Create tracker class from given data.

        Args:
            data (str or TrackerBase):
                Data describing the tracker. A :class:`TrackerBase` instance is
                returned unchanged; a string is looked up among the registered names.
            **kwargs:
                Additional keyword arguments passed to the tracker constructor. They
                are ignored when `data` already is a tracker instance.

        Returns:
            :class:`TrackerBase`: An instance representing the tracker

        Raises:
            ValueError: If the name is not registered or `data` has another type.
        """
        if isinstance(data, TrackerBase):
            return data
        if isinstance(data, str):
            try:
                tracker_cls = cls._subclasses[data]
            except KeyError as err:
                trackers = sorted(cls._subclasses)
                msg = f"Tracker `{data}` is not in {trackers}"
                raise ValueError(msg) from err
            return tracker_cls(**kwargs)
        msg = f"Unsupported tracker format: `{data}`."
        raise ValueError(msg)

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Initialize the tracker with information about the simulation.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation. The start time is read from
                ``info["controller"]["t_start"]`` and defaults to 0.

        Returns:
            float: The first time the tracker needs to handle data
        """
        if info is not None:
            t_start = info.get("controller", {}).get("t_start", 0)
        else:
            t_start = 0
        return self.interrupt.initialize(t_start)

    @abstractmethod
    def handle(self, field: FieldBase, t: float) -> None:
        """Handle data supplied to this tracker.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker, supplying additional information.

        The base implementation does nothing; subclasses may override it.

        Args:
            info (dict):
                Extra information from the simulation
        """
# Optional callable mapping a field and a time to a (transformed) field.
# NOTE(review): not used by the code visible in this module; presumably consumed by
# transformed trackers elsewhere -- confirm before changing it.
TransformationType = Callable[[FieldBase, float], FieldBase] | None
# Data accepted by `TrackerCollection.from_data`: a single tracker (instance or
# name), a sequence of them, or None. Note that `from_data` only handles sequences
# that are lists or tuples, and it additionally accepts a `TrackerCollection`.
TrackerCollectionDataType = Sequence[TrackerDataType] | TrackerDataType | None
class TrackerCollection:
    """List of trackers providing methods to handle them efficiently.

    Attributes:
        trackers (list):
            List of the trackers in the collection
    """

    tracker_action_times: list[float]
    """ list: Times at which the trackers need to be handled next """
    time_next_action: float
    """ float: The time of the next interrupt of the simulation """

    def __init__(self, trackers: list[TrackerBase] | None = None):
        """
        Args:
            trackers:
                Trackers that are to be handled. Any iterable (except a string) is
                accepted; its items are copied into a new list.

        Raises:
            ValueError: If `trackers` is a string or not iterable.
        """
        if trackers is None:
            self.trackers: list[TrackerBase] = []
        elif isinstance(trackers, str) or not hasattr(trackers, "__iter__"):
            # a string is iterable, but its characters are not trackers
            msg = f"`trackers` must be a list of trackers, not {trackers}"
            raise ValueError(msg)
        else:
            # materialize the iterable so `len()` and repeated iteration work and the
            # collection does not share its list with the caller
            self.trackers = list(trackers)
        # do not check trackers before everything was initialized
        self.tracker_action_times = []
        self.time_next_action = math.inf

    def __len__(self) -> int:
        """Returns the number of trackers in the collection."""
        return len(self.trackers)

    @classmethod
    def from_data(cls, data: TrackerCollectionDataType, **kwargs) -> TrackerCollection:
        """Create tracker collection from given data.

        Args:
            data:
                Data describing the tracker collection. This can be `None` (no
                trackers), a :class:`TrackerCollection`, a single tracker or tracker
                name, the special name `"auto"`, or a list or tuple of trackers and
                tracker names (`None` entries are skipped).
            **kwargs:
                Additional keyword arguments passed to the constructors of all
                trackers that are created from names

        Returns:
            :class:`TrackerCollection`:
                An instance representing the tracker collection

        Raises:
            TypeError: If `data` has an unsupported type.
            ValueError: If a tracker name is unknown.
        """
        if isinstance(data, str) and data == "auto":
            # include the `progress` tracker only if the optional `tqdm` is available
            if module_available("tqdm"):
                data = ("progress", "consistency")
            else:
                data = "consistency"

        if data is None:
            trackers: list[TrackerBase] = []
        elif isinstance(data, TrackerCollection):
            trackers = data.trackers
        elif isinstance(data, TrackerBase):
            trackers = [data]
        elif isinstance(data, str):
            trackers = [TrackerBase.from_data(data, **kwargs)]
        elif isinstance(data, (list, tuple)):
            # initialize trackers from a sequence
            trackers = []
            interrupt_ids: set[int] = set()
            for tracker in data:
                if tracker is None:
                    continue  # `None` entries are placeholders and get skipped
                # forward `kwargs` exactly like the single-name branch above, so that,
                # e.g., "auto" treats them the same whether or not `tqdm` is installed
                tracker_obj = TrackerBase.from_data(tracker, **kwargs)
                if id(tracker_obj.interrupt) in interrupt_ids:
                    # make sure that different trackers never use the same interrupt
                    # class, which would be problematic during iteration
                    tracker_obj.interrupt = tracker_obj.interrupt.copy()
                interrupt_ids.add(id(tracker_obj.interrupt))
                trackers.append(tracker_obj)
        else:
            msg = f"Cannot initialize trackers from class `{data.__class__}`"
            raise TypeError(msg)

        return cls(trackers)

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Initialize all trackers with information about the simulation.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the trackers
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time any tracker needs to handle data (`math.inf` if
            the collection is empty)
        """
        # initialize trackers and get their action times
        self.tracker_action_times = [
            tracker.initialize(field, info) for tracker in self.trackers
        ]
        # determine next time to check trackers
        self.time_next_action = min(self.tracker_action_times, default=math.inf)
        return self.time_next_action

    def handle(self, state: FieldBase, t: float, atol: float = 1.0e-8) -> float:
        """Handle all trackers that are due at time `t`.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
            atol (float):
                An absolute tolerance that is used to determine whether a
                tracker should be called now or whether the simulation should be
                carried on more timesteps. This is basically used to predict the
                next time to decide which one is closer.

        Returns:
            float: The next time the simulation needs to be interrupted to
            handle a tracker.

        Raises:
            StopIteration: Re-raised after all due trackers were handled if any of
            them raised it (e.g., :class:`FinishedSimulation`); if several did, the
            last one is re-raised.
        """
        # check each tracker to see whether we need to handle it
        stop_iteration_err = None
        for i, t_next in enumerate(self.tracker_action_times):
            if t > t_next - atol:
                try:
                    self.trackers[i].handle(state, t)
                except StopIteration as err:
                    # This tracker requested to stop the iteration. We save this
                    # information for later, so we can first handle all trackers.
                    stop_iteration_err = err
                # calculate next event (may skip some if too close)
                self.tracker_action_times[i] = self.trackers[i].interrupt.next(t)

        if stop_iteration_err is not None:
            raise stop_iteration_err

        # determine next time for checking handler
        if self.trackers:
            self.time_next_action = min(self.tracker_action_times)
        return self.time_next_action

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize all trackers, supplying additional information.

        Args:
            info (dict):
                Extra information from the simulation
        """
        for tracker in self.trackers:
            tracker.finalize(info=info)
def get_named_trackers() -> dict[str, type[TrackerBase]]:
    """Return every registered tracker class, keyed by its name.

    Deprecated: use :func:`registered_trackers` instead.

    Returns:
        dict: a new mapping from tracker names to the tracker classes
    """
    # Deprecated on 2025-12-23
    message = "`get_named_trackers` is deprecated. Use `registered_trackers` instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    # hand out a shallow copy so callers cannot modify the registry itself
    return dict(TrackerBase._subclasses)
def registered_trackers() -> dict[str, type[TrackerBase]]:
    """Collect all trackers that are currently registered.

    Registry entries whose name ends with ``"Base"`` are left out.

    Returns:
        dict: the registered tracker names mapped to their classes
    """
    result: dict[str, type[TrackerBase]] = {}
    for tracker_name, tracker_cls in TrackerBase._subclasses.items():
        if tracker_name.endswith("Base"):
            continue
        result[tracker_name] = tracker_cls
    return result
# Public API of this module (the deprecated `get_named_trackers` is not exported).
# NOTE(review): `TransformedTrackerBase` is exported (and listed in the module
# docstring) but is not defined in the code shown here; if it is really missing,
# `from <module> import *` raises AttributeError -- confirm it exists or remove it.
__all__ = [
    "FinishedSimulation",
    "TrackerBase",
    "TrackerCollection",
    "TransformedTrackerBase",
    "registered_trackers",
]