"""Module defining classes for tracking results from simulations.

The trackers defined in this module are:

.. autosummary::
   :nosignatures:

   CallbackTracker
   ProgressTracker
   PrintTracker
   PlotTracker
   LivePlotTracker
   DataTracker
   SteadyStateTracker
   RuntimeTracker
   ConsistencyTracker
   MaterialConservationTracker

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import inspect
import math
import os.path
import sys
import time
from datetime import timedelta
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable
import numpy as np
from ..fields import FieldCollection
from ..fields.base import FieldBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import module_available
from ..tools.output import get_progress_bar_class
from ..tools.parse_duration import parse_duration
from ..tools.typing import Real
from ..visualization.movies import Movie
from .base import FinishedSimulation, InfoDict, TrackerBase
from .interrupts import InterruptData, RealtimeInterrupts
if TYPE_CHECKING:
import pandas # noqa: ICN001
class CallbackTracker(TrackerBase):
    """Tracker that periodically hands the current state to a user function.

    Example:
        The callback tracker can be used to check for conditions during the simulation:

        .. code-block:: python

            def check_simulation(state, time):
                if state.integral < 0:
                    raise StopIteration


            tracker = CallbackTracker(check_simulation, interval="0:10")

        Adding :code:`tracker` to the simulation will perform a check every 10 real time
        seconds. If the integral of the entire state falls below zero, the simulation
        will be aborted.
    """

    @fill_in_docstring
    def __init__(
        self,
        func: Callable,
        interrupts: InterruptData = 1,
        *,
        interval=None,
    ):
        """
        Args:
            func:
                The function to call periodically. Its signature must be either
                `(state)` or `(state, time)`: `state` is the current state as an
                instance of :class:`~pde.fields.base.FieldBase` and `time` is the
                current time as a float. Only a view of the state is passed, so the
                function has to copy the data if it wants to keep it. Consequently,
                the function can modify the state in-place, and it can stop the
                simulation by raising the special exception :class:`StopIteration`.
            interrupts:
                {ARG_TRACKER_INTERRUPT}

        Raises:
            ValueError: if `func` does not declare exactly one or two parameters
        """
        super().__init__(interrupts=interrupts, interval=interval)
        self._callback = func
        # the number of declared parameters decides whether the time is passed, too
        self._num_args = n_params = len(inspect.signature(func).parameters)
        if n_params not in (1, 2):
            raise ValueError(
                "`func` must be a function accepting one or two arguments, not "
                f"{n_params}"
            )

    def handle(self, field: FieldBase, t: float) -> None:
        """Pass the current state (and, if requested, the time) to the callback.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        call_args = (field,) if self._num_args == 1 else (field, t)
        self._callback(*call_args)
class ProgressTracker(TrackerBase):
    """Tracker displaying a progress bar for the simulation time."""

    name = "progress"

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData | None = None,
        *,
        fancy: bool = True,
        ndigits: int = 5,
        leave: bool = True,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
                The default value `None` updates the progress bar approximately every
                (real) second.
            fancy (bool):
                Flag determining whether a fancy progress bar should be used in jupyter
                notebooks (if :mod:`ipywidgets` is installed)
            ndigits (int):
                The maximal number of digits after the decimal point that are shown.
            leave (bool):
                Whether to keep the progress bar visible after the simulation has
                finished (default: True)
        """
        # by default, refresh the bar roughly once per (real) second
        chosen = RealtimeInterrupts(duration=1) if interrupts is None else interrupts
        super().__init__(interrupts=chosen, interval=interval)
        self.fancy, self.ndigits, self.leave = fancy, ndigits, leave

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Create the progress bar spanning the simulated time interval.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation; the start and end times are
                read from its `controller` entry

        Returns:
            float: The first time the tracker needs to handle data
        """
        first_time = super().initialize(field, info)

        controller = info.get("controller", {}) if info is not None else {}
        bar_class = get_progress_bar_class(self.fancy)
        self.progress_bar = bar_class(
            total=controller.get("t_end"),
            initial=controller.get("t_start", 0),
            leave=self.leave,
        )
        self.progress_bar.set_description("Initializing")
        return first_time

    def handle(self, field: FieldBase, t: float) -> None:
        """Advance the progress bar to the current simulation time.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        total = self.progress_bar.total
        # never move the bar beyond its total (if a total is known)
        t_shown = min(t, total) if total else t
        self.progress_bar.n = round(t_shown, self.ndigits)
        self.progress_bar.set_description("")

    def finalize(self, info: InfoDict | None = None) -> None:
        """Complete the progress bar and close it (or mark it as successful).

        Args:
            info (dict):
                Extra information from the simulation
        """
        super().finalize(info)
        bar = self.progress_bar
        bar.set_description("")

        controller = {} if info is None else info.get("controller", {})
        # fill the bar completely if the simulation reached (or passed) its end time
        reached_end = controller.get("t_final", -math.inf) >= controller.get(
            "t_end", -math.inf
        )
        if reached_end and bar.total:
            bar.n = bar.total
            bar.refresh()

        # NOTE(review): only progress bars exposing `sp` (presumably the notebook
        # widget variant) get the green success style; confirm with installed tqdm
        show_success = (
            controller.get("successful", False)
            and self.leave
            and hasattr(bar, "sp")
        )
        if not show_success:
            bar.close()
            return

        # a successful run (e.g., one that reached steady state before t_final) is
        # shown in green; the bar is then disabled instead of being closed
        try:
            bar.sp(bar_style="success")
        except TypeError:
            bar.close()
        else:
            bar.disable = True

    def __del__(self):
        if not hasattr(self, "progress_bar"):
            return  # no progress bar was created since initialize() never ran
        if not self.progress_bar.disable:
            self.progress_bar.close()
class PrintTracker(TrackerBase):
    """Tracker writing a short summary of the state to a stream (default: stdout)"""

    name = "print"

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData = 1,
        stream: IO[str] = sys.stdout,
        *,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            stream:
                The text stream that receives the output
        """
        super().__init__(interrupts=interrupts, interval=interval)
        self.stream = stream

    def handle(self, field: FieldBase, t: float) -> None:
        """Write the time together with the mean and standard deviation of the data.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        values = field.data
        mean, spread = values.mean(), values.std()
        self.stream.write(f"t={t:g}, c={mean:.3g}±{spread:.3g}\n")
        self.stream.flush()
class PlotTracker(TrackerBase):
    """Tracker plotting data on screen, writing it to files, or creating a movie.

    This tracker can be used to create movies from simulations or to simply update a
    single image file on the fly (i.e. to monitor simulations running on a cluster). The
    default values of this tracker are chosen with regular output to a file in mind.

    Example:
        To create a movie while running the simulation, you can use

        .. code-block:: python

            movie_tracker = PlotTracker(interval=10, movie="my_movie.mp4")
            eq.solve(..., tracker=movie_tracker)

        This will write the file `my_movie.mp4` when the simulation finishes. Note that
        you can display the frames interactively by setting :code:`show=True`.
    """

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData = 1,
        *,
        title: str | Callable = "Time: {time:g}",
        output_file: str | None = None,
        movie: str | Path | Movie | None = None,
        show: bool | None = None,
        tight_layout: bool = False,
        max_fps: float = math.inf,
        plot_args: dict[str, Any] | None = None,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            title (str or callable):
                Title text of the figure. If this is a string, it is shown with a
                potential placeholder named `time` being replaced by the current
                simulation time. Conversely, if `title` is a function, it is called with
                the current state and the time as arguments. This function is expected
                to return a string.
            output_file (str, optional):
                Specifies a single image file, which is updated periodically, so that
                the progress can be monitored (e.g. on a compute cluster)
            movie (str or :class:`~pde.visualization.movies.Movie`):
                Create a movie. If a filename is given, all frames are written to this
                file in the format deduced from the extension after the simulation ran.
                If a :class:`~pde.visualization.movies.Movie` is supplied, frames are
                appended to the instance.
            show (bool, optional):
                Determines whether the plot is shown while the simulation is running. If
                set to `None`, the images are only shown if neither `output_file` nor
                `movie` is set, otherwise they are kept hidden. Note that showing the
                plot can slow down a simulation severely.
            tight_layout (bool):
                Determines whether :func:`~matplotlib.pyplot.tight_layout` is used.
            max_fps (float):
                Determines the maximal rate (frames per second) at which the plots are
                updated in real time during the simulation. Some plots are skipped if
                the tracker receives data at a higher rate. A larger value (e.g.,
                `math.inf`) can be used to ensure every frame is drawn, which might
                penalize the overall performance. Frames are never skipped when images
                are written to an output file or a movie.
            plot_args (dict):
                Extra arguments supplied to the plot call. For example, this can be used
                to specify axes ranges when a single panel is shown. For instance, the
                value :code:`{'ax_style': {'ylim': (0, 1)}}` enforces the y-axis to lie
                between 0 and 1.

        Note:
            If an instance of :class:`~pde.visualization.movies.Movie` is given as the
            `movie` argument, it can happen that the movie is not written to the file
            when the simulation ends. This is because the movie could still be extended
            by appending frames. To write the movie to a file call its
            :meth:`~pde.visualization.movies.Movie.save` method. Beside adding frames
            before and after the simulation, an explicit movie object can also be used
            to adjust the output. For instance, the following example code creates a
            movie with a framerate of 15, a resolution of 200 dpi, and a bitrate of 6000
            kilobits per second:

            .. code-block:: python

                movie = Movie("movie.mp4", framerate=15, dpi=200, bitrate=6000)
                eq.solve(..., tracker=PlotTracker(1, movie=movie))
                movie.save()
        """
        # initialize the tracker
        super().__init__(interrupts=interrupts, interval=interval)
        self.title = title
        self.output_file = output_file
        self.tight_layout = tight_layout
        self.max_fps = max_fps
        self.plot_args = {} if plot_args is None else plot_args.copy()
        # only create the plot without showing it, since the plotting context takes
        # care of displaying the figure itself
        self.plot_args["action"] = "none"

        # initialize the movie class
        if movie is None:
            self.movie: Movie | None = None
            self._save_movie = False
        elif isinstance(movie, Movie):
            # frames are appended to the supplied instance, which the caller saves
            self.movie = movie
            self._save_movie = False
        elif isinstance(movie, (str, Path)):
            # this movie is owned by the tracker and written to disk in `finalize`
            self.movie = Movie(filename=str(movie))
            self._save_movie = True
        else:
            raise TypeError(f"Unknown type of `movie`: {movie.__class__.__name__}")

        # images are written whenever frames are recorded for any movie (owned or
        # user-supplied) or an output file is updated; such frames must not be skipped
        self._write_images = self.movie is not None or bool(self.output_file)
        # determine whether to show the images interactively
        self.show = (not self._write_images) if show is None else show

    def initialize(self, state: FieldBase, info: InfoDict | None = None) -> float:
        """Initialize the tracker with information about the simulation.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time the tracker needs to handle data
        """
        import matplotlib.pyplot as plt

        from ..tools.plotting import get_plotting_context

        # initialize the plotting context
        self._context = get_plotting_context(title="Initializing...", show=self.show)

        # do the actual plotting
        with self._context:
            self._plot_reference = state.plot(**self.plot_args)
            if self.tight_layout:
                plt.gcf().tight_layout()

        # determine how subsequent frames are drawn
        if not self._context.supports_update:
            # the context cannot reuse figures, so every frame is plotted from scratch
            self._update_method = "replot"
        elif not hasattr(state.plot, "update_method"):
            raise RuntimeError(
                "PlotTracker does not work since the state of type "
                f"{state.__class__.__name__} does not use the plot protocol of "
                "`pde.tools.plotting`."
            )
        elif state.plot.update_method is not None:
            # the plotting method supports updating the plotted data directly
            self._update_method = "update_data"
        else:
            # the figure is reused, but the content has to be redrawn
            mpl_class = state.plot.mpl_class  # type: ignore
            if mpl_class == "axes":
                self._update_method = "update_ax"
            elif mpl_class == "figure":
                self._update_method = "update_fig"
            else:
                raise RuntimeError(f"Unknown mpl_class on plot method: {mpl_class}")

        self._logger.info('Update method: "%s"', self._update_method)
        self._last_update = time.monotonic()
        return super().initialize(state, info=info)

    def handle(self, state: FieldBase, t: float) -> None:
        """Update the plot and write it to the output file and/or the movie.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        import matplotlib.pyplot as plt

        if not self._write_images:
            # frames that are only displayed may be skipped to respect `max_fps`
            time_passed = time.monotonic() - self._last_update
            if time_passed < 1 / self.max_fps:
                return  # we just recently updated the image

        if callable(self.title):
            self._context.title = str(self.title(state, t))
        else:
            self._context.title = self.title.format(time=t)

        # update the plot in the correct plotting context
        with self._context:
            if self._update_method == "update_data":
                # the state supports updating the plot data
                update_func = getattr(state, state.plot.update_method)  # type: ignore
                update_func(self._plot_reference)

            elif self._update_method == "update_ax":
                fig = self._context.fig
                fig.clf()  # type: ignore
                ax = fig.add_subplot(1, 1, 1)  # type: ignore
                state.plot(ax=ax, **self.plot_args)
                # the figure was cleared, so the layout has to be adjusted again
                if self.tight_layout:
                    plt.gcf().tight_layout()

            elif self._update_method == "update_fig":
                fig = self._context.fig
                fig.clf()  # type: ignore
                state.plot(fig=fig, **self.plot_args)
                if self.tight_layout:
                    plt.gcf().tight_layout()

            elif self._update_method == "replot":
                state.plot(**self.plot_args)
                if self.tight_layout:
                    plt.gcf().tight_layout()

            else:
                raise RuntimeError(f"Unknown update method `{self._update_method}`")

        if self.output_file and self._context.fig is not None:
            self._context.fig.savefig(self.output_file)
        if self.movie is not None:
            self.movie.add_figure(self._context.fig)

        self._last_update = time.monotonic()

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker, writing an owned movie and closing hidden plots.

        Args:
            info (dict):
                Extra information from the simulation
        """
        super().finalize(info)
        if self._save_movie and self.movie is not None:
            # write out movie file
            self.movie.save()
            # end recording the movie (e.g. delete temporary files)
            self.movie._end()
        if not self.show:
            self._context.close()
class LivePlotTracker(PlotTracker):
    """PlotTracker with defaults for live plotting.

    The only difference to :class:`PlotTracker` are the changed default values. Output
    is shown on screen by default, the plot is refreshed at most twice per second, and
    the default `interrupts` (`"0:03"`) are more suitable for interactive plotting. In
    particular, this tracker can be enabled by simply listing 'plot' as a tracker.
    """

    name = "plot"

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData = "0:03",
        *,
        show: bool = True,
        max_fps: float = 2,
        interval=None,
        **kwargs,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            show (bool, optional):
                Determines whether the plot is shown while the simulation is
                running. If `False`, the plots are created in the background.
                Note that showing the plot can slow down a simulation severely.
            max_fps (float):
                Determines the maximal rate (frames per second) at which the plots are
                updated. Some plots are skipped if the tracker receives data at a higher
                rate. A larger value (e.g., `math.inf`) can be used to ensure every
                frame is drawn, which might penalize the overall performance.
            **kwargs:
                All other arguments are forwarded to :class:`PlotTracker`, e.g.,
                `title` (a format string with a `time` placeholder, or a callable
                receiving the state and the time), `output_file`, `movie`,
                `tight_layout`, and `plot_args` (e.g.,
                `{'ax_style': {'ylim': (0, 1)}}` enforces the y-axis to lie between 0
                and 1).
        """
        super().__init__(
            interrupts=interrupts,
            interval=interval,
            show=show,
            max_fps=max_fps,
            **kwargs,
        )
class DataTracker(CallbackTracker):
    """Tracker storing custom data obtained by calling a function.

    Example:
        The data tracker can be used to gather statistics during the run

        .. code-block:: python

            def get_statistics(state, time):
                return {"mean": state.data.mean(), "variance": state.data.var()}


            data_tracker = DataTracker(get_statistics, interval=10)

        Adding :code:`data_tracker` to the simulation will gather the statistics every
        10 time units. After the simulation, the final result will be accessible via the
        :attr:`data` attribute or conveniently as a :class:`pandas.DataFrame` via the
        :attr:`dataframe` property.

    Attributes:
        times (list):
            The time points at which the data is stored
        data (list):
            The actually stored data, which is a list of the objects returned by
            the callback function.
    """

    @fill_in_docstring
    def __init__(
        self,
        func: Callable,
        interrupts: InterruptData = 1,
        *,
        filename: str | None = None,
        interval=None,
    ):
        """
        Args:
            func:
                The function to call periodically. The function signature
                should be `(state)` or `(state, time)`, where `state` contains
                the current state as an instance of
                :class:`~pde.fields.FieldBase` and `time` is a
                float value indicating the current time. Note that only a view
                of the state is supplied, implying that a copy needs to be made
                if the data should be stored.
                Typical return values of the function are either a single
                number, a numpy array, a list of numbers, or a dictionary to
                return multiple numbers with assigned labels.
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            filename (str):
                A path to a file to which the data is written at the end of the
                tracking. The data format will be determined by the extension
                of the filename. '.pickle' (or '.pkl') indicates a python pickle
                file storing a tuple `(self.times, self.data)`, whereas any other
                data format requires :mod:`pandas`.
        """
        super().__init__(func=func, interrupts=interrupts, interval=interval)
        self.filename = filename
        self.times: list[float] = []
        self.data: list[Any] = []

    def handle(self, field: FieldBase, t: float) -> None:
        """Call the function and store its result together with the time.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        # evaluate the callback before storing anything, so that an exception raised
        # by it (e.g., StopIteration) cannot leave `times` and `data` out of sync
        if self._num_args == 1:
            value = self._callback(field)
        else:
            value = self._callback(field, t)
        self.times.append(t)
        self.data.append(value)

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker and write the data to `filename` (if given).

        Args:
            info (dict):
                Extra information from the simulation
        """
        super().finalize(info)
        if self.filename:
            self.to_file(self.filename)

    @property
    def dataframe(self) -> pandas.DataFrame:
        """:class:`pandas.DataFrame`: the data in a dataframe.

        If `func` returns a dictionary, the keys are used as column names.
        Otherwise, the returned data is enumerated starting with '0'. In any
        case the time point at which the data was recorded is stored in the
        column 'time'.
        """
        import pandas as pd

        df = pd.DataFrame(self.data)
        # prepend the time points as the first column (the default index is kept)
        df.insert(0, "time", self.times)
        return df

    def to_file(self, filename: str | Path, **kwargs):
        r"""Store data in a file.

        The extension of the filename determines what format is being used. For
        instance, '.pickle' (or '.pkl') indicates a python pickle file storing a tuple
        `(self.times, self.data)`, whereas any other data format requires
        :mod:`pandas`. Supported formats include 'csv', 'json', 'xls', and 'xlsx'.

        Args:
            filename (str):
                Path where the data is stored
            \**kwargs:
                Additional parameters may be supported for some formats

        Raises:
            ValueError: if the file extension is not supported
        """
        extension = Path(filename).suffix.lower()
        if extension in {".pickle", ".pkl"}:
            import pickle

            with Path(filename).open("wb") as fp:
                pickle.dump((self.times, self.data), fp, **kwargs)
        elif extension == ".csv":
            self.dataframe.to_csv(filename, **kwargs)
        elif extension == ".json":
            self.dataframe.to_json(filename, **kwargs)
        elif extension in {".xls", ".xlsx"}:
            self.dataframe.to_excel(filename, **kwargs)
        else:
            raise ValueError(f"Unsupported file extension `{extension}`")
class SteadyStateTracker(TrackerBase):
    """Tracker aborting the simulation once steady state is reached.

    Steady state is obtained when the state does not change anymore, i.e., when the
    evolution rate is close to zero. If the argument `evolution_rate` is specified, it
    is used to calculate the evolution rate directly. If it is omitted, the evolution
    rate is estimated by comparing the current state `cur` to the state `prev` at the
    previous time step. In both cases, convergence is assumed when the absolute value of
    the evolution rate falls below :code:`atol + rtol * |cur|` for all finite points.
    Here, `atol` and `rtol` denote absolute and relative tolerances, respectively.
    """

    name = "steady_state"

    progress_bar_format = (
        "Convergence: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]"
    )
    """Determines the format of the progress bar shown when `progress = True`"""

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData | None = None,
        atol: float = 1e-8,
        rtol: float = 1e-5,
        *,
        progress: bool = False,
        evolution_rate: Callable[[np.ndarray, float], np.ndarray] | None = None,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
                The default value `None` checks for the steady state approximately every
                (real) second.
            atol (float):
                Absolute tolerance that must be reached to abort the simulation
            rtol (float):
                Relative tolerance that must be reached to abort the simulation
            progress (bool):
                Flag indicating whether the progress towards convergence is shown
                graphically during the simulation (requires :mod:`tqdm`)
            evolution_rate (callable):
                Function to evaluate the current evolution rate. If omitted, the
                evolution rate is estimated from the change in the state variable,
                which can be less accurate. A suitable form of the function is returned
                by `eq.make_pde_rhs(state)` when `eq` is the PDE class.
        """
        if interrupts is None:
            interrupts = RealtimeInterrupts(duration=1)
        super().__init__(interrupts=interrupts, interval=interval)
        self.atol = atol
        self.rtol = rtol
        self.evolution_rate = evolution_rate
        self.progress = progress and module_available("tqdm")
        self._progress_bar: Any = None
        self._pbar_offset: float = 0  # required for calculating progress
        self._last_data: np.ndarray | None = None
        self._last_time: float | None = None
        self._best_rate_max: np.ndarray | None = None

    def handle(self, field: FieldBase, t: float) -> None:
        """Check for convergence and stop the simulation once it is reached.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time

        Raises:
            FinishedSimulation: when the evolution rate is within the tolerances
        """
        finite = np.isfinite(field.data)  # ignore infinite and nan data

        # determine the evolution rate at all finite points
        if self.evolution_rate is not None:
            # use the evolution_rate function to calculate the rate
            evolution_rate = self.evolution_rate(field.data, t)[finite]
        elif self._last_data is None:
            # first call: store the state so the rate can be estimated next time
            self._last_data = field.data.copy()
            self._last_time = t
            return  # do not output anything since we don't know `evolution_rate` yet
        else:
            dt = t - self._last_time  # type: ignore
            if dt == 0:
                # two states at the same time do not allow estimating a rate (and
                # would lead to a division by zero)
                return
            # get evolution rate from the difference of current state to previous one
            diff = self._last_data[finite] - field.data[finite]
            evolution_rate = diff / dt
            # save current data for next comparison
            self._last_data[:] = field.data
            self._last_time = t

        if evolution_rate.size == 0:
            return  # no finite data points, so convergence cannot be judged

        # calculate the maximal deviation of the evolution rate from zero, subtracting
        # the relative tolerance with respect to the field values
        rate_abs = np.abs(evolution_rate) - self.rtol * np.abs(field.data[finite])
        rate_abs_max = np.max(rate_abs)

        # check whether the simulation has converged
        if rate_abs_max <= self.atol:
            if self.progress and self._progress_bar is not None:
                # advance progress bar to 100%
                self._progress_bar.n = self._pbar_offset - np.log10(self.atol)
                try:
                    self._progress_bar.disp(bar_style="success", check_delay=False)
                except (TypeError, AttributeError):
                    # progress bars without `disp` (e.g., text bars) are closed instead
                    self._progress_bar.close()
            raise FinishedSimulation("Reached stationary state")

        if self.progress:
            # show progress of the convergence on a logarithmic scale
            if self._best_rate_max is None:
                # initialize the progress bar
                pb_cls = get_progress_bar_class()
                self._pbar_offset = np.log10(rate_abs_max)
                self._progress_bar = pb_cls(
                    total=self._pbar_offset - np.log10(self.atol),
                    bar_format=self.progress_bar_format,
                )
                self._best_rate_max = rate_abs_max
            elif rate_abs_max < self._best_rate_max:
                # update progress bar if simulation got closer to convergence
                self._progress_bar.n = self._pbar_offset - np.log10(rate_abs_max)
                self._progress_bar.refresh()
                self._best_rate_max = rate_abs_max

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker, closing the convergence progress bar if one exists.

        Args:
            info (dict):
                Extra information from the simulation
        """
        super().finalize(info)
        if self._progress_bar is not None:
            # otherwise the bar stays open when the run ends without convergence
            self._progress_bar.close()
class RuntimeTracker(TrackerBase):
    """Tracker stopping the simulation once a given wall-clock duration has elapsed."""

    @fill_in_docstring
    def __init__(
        self, max_runtime: Real | str, interrupts: InterruptData = 1, *, interval=None
    ):
        """
        Args:
            max_runtime (float or str):
                The maximal runtime of the simulation, after which it is interrupted.
                A number is interpreted as seconds, whereas any other value is
                converted to a string and parsed using the function
                :func:`~pde.tools.parse_duration.parse_duration`.
            interrupts:
                {ARG_TRACKER_INTERRUPT}
        """
        super().__init__(interrupts=interrupts, interval=interval)
        try:
            seconds = float(max_runtime)
        except ValueError:
            # not a plain number, e.g. a duration string like "1:30"
            seconds = parse_duration(str(max_runtime)).total_seconds()
        self.max_runtime = seconds

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Record the wall-clock deadline and initialize the base tracker.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time the tracker needs to handle data
        """
        started = time.monotonic()
        self.max_time = started + self.max_runtime
        return super().initialize(field, info)

    def handle(self, field: FieldBase, t: float) -> None:
        """Stop the simulation if the deadline has passed.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time

        Raises:
            FinishedSimulation: once the maximal runtime is exceeded
        """
        now = time.monotonic()
        if now > self.max_time:
            limit = timedelta(seconds=self.max_runtime)
            raise FinishedSimulation(f"Reached maximal runtime of {limit}")
class ConsistencyTracker(TrackerBase):
    """Tracker stopping the simulation as soon as the state contains non-finite values."""

    name = "consistency"

    @fill_in_docstring
    def __init__(self, interrupts: InterruptData | None = None, *, interval=None):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
                The default value `None` checks for consistency approximately every
                (real) second.
        """
        chosen = RealtimeInterrupts(duration=1) if interrupts is None else interrupts
        super().__init__(interrupts=chosen, interval=interval)

    def handle(self, field: FieldBase, t: float) -> None:
        """Raise if any entry of the state is infinite or NaN.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time

        Raises:
            StopIteration: if the field data is not entirely finite
        """
        if np.isfinite(field.data).all():
            return
        raise StopIteration("Field was not finite")
class MaterialConservationTracker(TrackerBase):
    """Tracker interrupting the simulation when material conservation is broken."""

    name = "material_conservation"

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData = 1,
        atol: float = 1e-4,
        rtol: float = 1e-4,
        *,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            atol (float):
                Absolute tolerance for deviations of the amount of material
            rtol (float):
                Relative tolerance for deviations of the amount of material
        """
        super().__init__(interrupts=interrupts, interval=interval)
        self.atol = atol
        self.rtol = rtol

    @staticmethod
    def _get_magnitudes(field: FieldBase):
        """Return the field's magnitude, or an array with one per component."""
        if isinstance(field, FieldCollection):
            return np.array([component.magnitude for component in field])
        return field.magnitude  # type: ignore

    def initialize(self, field: FieldBase, info: InfoDict | None = None) -> float:
        """Store the initial magnitudes as reference for later comparisons.

        Args:
            field (:class:`~pde.fields.base.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time the tracker needs to handle data
        """
        self._reference = self._get_magnitudes(field)
        return super().initialize(field, info)

    def handle(self, field: FieldBase, t: float) -> None:
        """Compare the current magnitudes with the stored reference values.

        Args:
            field (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time

        Raises:
            StopIteration: if any magnitude deviates beyond the tolerances
        """
        conserved = np.isclose(
            self._get_magnitudes(field),
            self._reference,
            rtol=self.rtol,
            atol=self.atol,
        )
        if np.all(conserved):
            return

        if isinstance(field, FieldCollection):
            # report the indices of the components that violate conservation
            msg = f"Material of field {np.flatnonzero(~conserved)} is not conserved"
        else:
            msg = "Material is not conserved"
        raise StopIteration(msg)
# public API of this module; the same trackers are listed in the module docstring
__all__ = [
    "CallbackTracker",
    "ProgressTracker",
    "PrintTracker",
    "PlotTracker",
    "LivePlotTracker",
    "DataTracker",
    "SteadyStateTracker",
    "RuntimeTracker",
    "ConsistencyTracker",
    "MaterialConservationTracker",
]