# Source code for pde.trackers.interrupts

"""
Module defining classes for time interrupts for trackers

The provided interrupt classes are:

.. autosummary::
   :nosignatures:

   FixedInterrupts
   ConstantInterrupts
   LogarithmicInterrupts
   RealtimeInterrupts

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> 
"""

from __future__ import annotations

import copy
import math
import time
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, TypeVar, Union

import numpy as np

from ..tools.parse_duration import parse_duration

# type alias: a dictionary with string keys, or None
InfoDict = Union[dict[str, Any], None]

# type variable bound to `InterruptsBase`; used so that `copy` is typed to return
# the same (sub)class it was called on
TInterrupt = TypeVar("TInterrupt", bound="InterruptsBase")


class InterruptsBase(metaclass=ABCMeta):
    """base class for implementing interrupts

    Subclasses determine the simulation times at which interrupts happen:
    :meth:`initialize` returns the first interrupt time and :meth:`next` returns
    each subsequent one.
    """

    dt: float
    """float: current time difference between interrupts"""

    @abstractmethod
    def copy(self: TInterrupt) -> TInterrupt:
        """return a copy of this instance"""

    @abstractmethod
    def initialize(self, t: float) -> float:
        """initialize the interrupt class

        Args:
            t (float):
                The starting time of the simulation

        Returns:
            float: The first time the simulation needs to be interrupted
        """

    @abstractmethod
    def next(self, t: float) -> float:
        """computes the next time point

        Args:
            t (float):
                The current time point of the simulation. The returned next time
                point is not earlier than this time, so interrupts lying before `t`
                might be skipped.

        Returns:
            float: The next time point. Implementations may return `math.inf` to
            signal that no further interrupts will happen.
        """
class FixedInterrupts(InterruptsBase):
    """class representing a list of interrupt times"""

    def __init__(self, interrupts: np.ndarray | Sequence[float]):
        """
        Args:
            interrupts (:class:`~numpy.ndarray` or list of float):
                The simulation times at which interrupts happen. The data must be
                one-dimensional; entries are visited in the given order.

        Raises:
            ValueError: if `interrupts` is not one-dimensional
        """
        self.interrupts = np.atleast_1d(interrupts)
        if self.interrupts.ndim != 1:
            raise ValueError("`interrupts` must be a 1d sequence")
        # index of the most recently returned interrupt; -1 means none returned yet.
        # Set here so `next` does not fail with AttributeError before `initialize`
        self._index = -1

    def __repr__(self):
        return f"{self.__class__.__name__}(interrupts={self.interrupts})"

    def copy(self):
        """return a copy of this instance with an independent interrupt array"""
        return self.__class__(interrupts=self.interrupts.copy())

    def initialize(self, t: float) -> float:
        """restart at the beginning of the interrupt list

        Args:
            t (float): The starting time of the simulation

        Returns:
            float: The first listed interrupt that is not earlier than `t`, or
            `math.inf` if there is none
        """
        self._index = -1
        return self.next(t)

    def next(self, t: float) -> float:
        """return the next listed interrupt that is not earlier than `t`

        Entries lying before `t` are skipped. As a side effect, `dt` is set to the
        difference between the returned time and the previously returned interrupt
        (or `t` for the first interrupt).

        Args:
            t (float): The current time point of the simulation

        Returns:
            float: The next interrupt time, or `math.inf` once the list is exhausted
        """
        num = len(self.interrupts)

        # Determine time of the last interrupt. This value does not make much sense
        # for the first interrupt, so we simply use the current time
        if self._index < 0:
            t_last = t
        elif self._index < num:
            t_last = self.interrupts[self._index]
        else:
            return math.inf  # list already exhausted -> never break again

        # advance to the next entry that is not earlier than the current time `t`
        self._index += 1
        while self._index < num and self.interrupts[self._index] < t:
            self._index += 1
        if self._index >= num:
            return math.inf  # list has been exhausted -> never break again

        t_next: float = self.interrupts[self._index]
        self.dt = t_next - t_last
        return t_next
class ConstantInterrupts(InterruptsBase):
    """class representing equidistantly spaced time interrupts"""

    def __init__(self, dt: float = 1, t_start: float | None = None):
        """
        Args:
            dt (float):
                The duration between subsequent interrupts. This is measured in
                simulation time units and must be positive.
            t_start (float, optional):
                The time after which the tracker becomes active. If omitted, the
                tracker starts recording right away. This argument can be used for
                an initial equilibration period during which no data is recorded.

        Raises:
            ValueError: if `dt` is not a positive number
        """
        self.dt = float(dt)
        # a non-positive (or NaN) step would later divide by zero in `next` or
        # produce interrupt times lying in the past
        if not self.dt > 0:
            raise ValueError(f"`dt` must be positive, but got {dt}")
        self.t_start = t_start
        self._t_next: float | None = None  # next time it should be called

    def __repr__(self):
        return f"{self.__class__.__name__}(dt={self.dt:g}, t_start={self.t_start})"

    def copy(self):
        """return a shallow copy of this instance"""
        return copy.copy(self)

    def initialize(self, t: float) -> float:
        """set the first interrupt time

        Args:
            t (float): The starting time of the simulation

        Returns:
            float: `t`, or `max(t, t_start)` if `t_start` was given
        """
        if self.t_start is None:
            self._t_next = t
        else:
            self._t_next = max(t, self.t_start)
        return self._t_next

    def next(self, t: float) -> float:
        """advance the interrupt time by multiples of `dt`

        Args:
            t (float): The current time point of the simulation

        Returns:
            float: The next interrupt time, which is not earlier than `t`

        Raises:
            RuntimeError: if :meth:`initialize` has not been called
        """
        if self._t_next is None:
            raise RuntimeError("`initialize` must be called before `next`")

        # move next interrupt time by the appropriate interrupt
        self._t_next += self.dt

        # make sure that the new interrupt time is not in the past
        if self._t_next <= t:
            # add the smallest multiple of `dt` so that _t_next is not smaller than t
            n = math.ceil((t - self._t_next) / self.dt)
            self._t_next += self.dt * n
            # adjust in special cases where float-point math fails us
            if self._t_next < t:
                self._t_next += self.dt

        return self._t_next
class LogarithmicInterrupts(ConstantInterrupts):
    """class representing logarithmically spaced time interrupts"""

    def __init__(
        self, dt_initial: float = 1, factor: float = 1, t_start: float | None = None
    ):
        """
        Args:
            dt_initial (float):
                The initial duration between subsequent interrupts. This is
                measured in simulation time units.
            factor (float):
                The factor by which the time between interrupts is increased every
                time. Values larger than one lead to time interrupts that are
                increasingly further apart. Must be positive.
            t_start (float, optional):
                The time after which the tracker becomes active. If omitted, the
                tracker starts recording right away. This argument can be used for
                an initial equilibration period during which no data is recorded.

        Raises:
            ValueError: if `factor` is not a positive number
        """
        factor = float(factor)
        # zero would divide by zero below; negative values make `dt` alternate sign
        if not factor > 0:
            raise ValueError(f"`factor` must be positive, but got {factor}")
        # `next` multiplies `dt` by `factor` before advancing, so the first interval
        # after `initialize` equals `dt_initial` (up to rounding)
        super().__init__(dt=dt_initial / factor, t_start=t_start)
        self.factor = factor

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(dt={self.dt:g}, "
            f"factor={self.factor:g}, t_start={self.t_start})"
        )

    def next(self, t: float) -> float:
        """grow the interval by `factor` and return the next interrupt time

        Args:
            t (float): The current time point of the simulation

        Returns:
            float: The next interrupt time, which is not earlier than `t`
        """
        self.dt *= self.factor
        return super().next(t)
class RealtimeInterrupts(ConstantInterrupts):
    """class representing time interrupts spaced equidistantly in real time

    This spacing is only achieved approximately and depends on the initial value
    set by `dt_initial` and the actual variation in computation speed.
    """

    def __init__(self, duration: float | str, dt_initial: float = 0.01):
        """
        Args:
            duration (float or str):
                The duration (in real seconds) that the interrupts should be spaced
                apart. The duration can also be given as a string, which is then
                parsed using the function
                :func:`~pde.tools.parse_duration.parse_duration`.
            dt_initial (float):
                The initial duration between subsequent interrupts. This is
                measured in simulation time units.
        """
        super().__init__(dt=dt_initial)
        try:
            self.duration = float(duration)
        except (TypeError, ValueError):
            # not convertible to a number -> parse its string form as a duration
            td = parse_duration(str(duration))
            self.duration = td.total_seconds()
        self._last_time: float | None = None  # wall-clock time of last interrupt

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(duration={self.duration:g}, "
            f"dt_initial={self.dt:g})"
        )

    def initialize(self, t: float) -> float:
        """record the current wall-clock time and set the first interrupt

        Args:
            t (float): The starting time of the simulation

        Returns:
            float: The first time the simulation needs to be interrupted
        """
        self._last_time = time.monotonic()
        return super().initialize(t)

    def next(self, t: float) -> float:
        """adapt `dt` to the measured wall-clock time and return the next interrupt

        The step `dt` is rescaled by `duration / time_passed`, where `time_passed`
        is the wall-clock time since the previous call, and smoothed with a
        geometric average. If no wall-clock time has passed, `dt` is doubled.

        Args:
            t (float): The current time point of the simulation

        Returns:
            float: The next interrupt time, which is not earlier than `t`
        """
        if self._last_time is None:
            self._last_time = time.monotonic()
        else:
            # adapt time step
            current_time = time.monotonic()
            time_passed = current_time - self._last_time
            if time_passed > 0:
                # predict new time step, but limit it from below, to avoid problems with
                # simulations where a single step takes a long time
                dt_predict = max(1e-3, self.dt * self.duration / time_passed)
                # use geometric average to provide some smoothing
                self.dt = math.sqrt(self.dt * dt_predict)
            else:
                self.dt *= 2
            self._last_time = current_time
        return super().next(t)
# all types that `parse_interrupt` accepts for describing interrupts
InterruptData = Union[
    InterruptsBase,
    float,
    str,
    Sequence[float],
    np.ndarray,
]
def parse_interrupt(data: InterruptData) -> InterruptsBase:
    """create interrupt class from various data formats

    Args:
        data (str or number or :class:`InterruptsBase`):
            Data determining the interrupt class. If this is a
            :class:`InterruptsBase`, it is simply returned, numbers (including
            numpy integer and floating-point scalars) imply
            :class:`ConstantInterrupts`, a string is parsed as a time for
            :class:`RealtimeInterrupts`, and lists are interpreted as
            :class:`FixedInterrupts`.

    Returns:
        :class:`InterruptsBase`: An instance that represents the interrupt

    Raises:
        TypeError: if `data` is none of the supported formats
    """
    if isinstance(data, InterruptsBase):
        return data
    # numpy scalars such as `np.int64` or `np.float32` are not subclasses of the
    # builtin numbers and have no `__iter__`, so they are listed explicitly
    elif isinstance(data, (int, float, np.integer, np.floating)):
        return ConstantInterrupts(data)
    elif isinstance(data, str):
        return RealtimeInterrupts(data)
    elif hasattr(data, "__iter__"):
        return FixedInterrupts(data)
    else:
        raise TypeError(f"Cannot parse interrupt data `{data}`")