"""Functions and classes for plotting simulation data.
.. autosummary::
:nosignatures:
ScalarFieldPlot
plot_magnitudes
plot_kymograph
plot_kymographs
plot_interactive
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import logging
import math
import time
import warnings
from typing import Any, Callable, Literal, Union
import numpy as np
from ..fields import FieldCollection
from ..fields.base import FieldBase
from ..fields.datafield_base import DataFieldBase
from ..storage.base import StorageBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import module_available
from ..tools.output import display_progress
from ..tools.plotting import (
PlotReference,
napari_add_layers,
napari_viewer,
plot_on_axes,
plot_on_figure,
)
_logger = logging.getLogger(__name__)
ScaleData = Union[str, float, tuple[float, float]]
def _add_horizontal_colorbar(im, ax, num_loc: int = 5) -> None:
"""Adds a horizontal colorbar for image `im` to the axis `ax`
Args:
im: The result of calling :func:`matplotlib.pyplot.imshow`
ax: The matplotlib axes to which the colorbar is added
num_loc (int): Number of ticks
"""
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig = ax.figure
divider = make_axes_locatable(ax)
cax = divider.new_vertical(size="5%", pad=0.1, pack_start=True)
fig.add_axes(cax)
cb = fig.colorbar(im, cax=cax, orientation="horizontal")
cb.locator = MaxNLocator(num_loc)
cb.update_ticks()
[docs]
class ScalarFieldPlot:
"""Class managing compound plots of scalar fields."""
@fill_in_docstring
def __init__(
self,
fields: FieldBase,
quantities=None,
scale: ScaleData = "automatic",
fig=None,
title: str | None = None,
tight: bool = False,
show: bool = True,
):
"""
Args:
fields (:class:`~pde.fields.base.FieldBase`):
Collection of fields
quantities:
{ARG_PLOT_QUANTITIES}
scale (str, float, tuple of float):
{ARG_PLOT_SCALE}
fig (:class:`matplotlib.figure.Figure):
Figure to be used for plotting. If `None`, a new figure is
created.
title (str):
Title of the plot.
tight (bool):
Whether to call :func:`matplotlib.pyplot.tight_layout`. This
affects the layout of all plot elements.
show (bool):
Flag determining whether to show a plot. If `False`, the plot is
kept in the background, which can be useful if it only needs to
be written to a file.
"""
self.grid = fields.grid
self.quantities = self._prepare_quantities(fields, quantities, scale=scale)
self.show = show
# figure out whether plots are shown in jupyter notebook
try:
from ipywidgets import Output
except ImportError:
ipython_plot = False
else:
ipython_plot = self.show
if ipython_plot:
# plotting is done in an ipython environment using widgets
from IPython.display import display
self._ipython_out = Output()
with self._ipython_out:
self._initialize(fields, scale, fig, title, tight)
display(self._ipython_out)
else:
# plotting is done using a simple matplotlib backend
self._ipython_out = None
self._initialize(fields, scale, fig, title, tight)
if self.show:
self._show()
[docs]
@classmethod
@fill_in_docstring
def from_storage(
cls,
storage: StorageBase,
quantities=None,
scale: ScaleData = "automatic",
tight: bool = False,
show: bool = True,
) -> ScalarFieldPlot:
"""Create ScalarFieldPlot from storage.
Args:
storage (:class:`~pde.storage.base.StorageBase`):
Instance of the storage class that contains the data
quantities:
{ARG_PLOT_QUANTITIES}
scale (str, float, tuple of float):
{ARG_PLOT_SCALE}
tight (bool):
Whether to call :func:`matplotlib.pyplot.tight_layout`. This
affects the layout of all plot elements.
show (bool):
Flag determining whether to show a plot. If `False`, the plot is
kept in the background.
Returns:
:class:`~pde.visualization.plotting.ScalarFieldPlot`
"""
fields = storage[0]
if not isinstance(fields, FieldBase):
raise RuntimeError("Storage must contain fields")
# prepare the data that needs to be plotted
quantities = cls._prepare_quantities(fields, quantities=quantities, scale=scale)
# resolve automatic scaling
for quantity_row in quantities:
for quantity in quantity_row:
if quantity.get("scale", "automatic") == "automatic":
vmin, vmax = +math.inf, -math.inf
for data in storage:
field = extract_field(data, quantity.get("source"))
img = field.get_image_data()
vmin = np.nanmin([np.nanmin(img["data"]), vmin])
vmax = np.nanmax([np.nanmax(img["data"]), vmax])
quantity["scale"] = (vmin, vmax)
# actually setup
return cls(fields, quantities, tight=tight, show=show)
@staticmethod
@fill_in_docstring
def _prepare_quantities(
fields: FieldBase, quantities, scale: ScaleData = "automatic"
) -> list[list[dict[str, Any]]]:
"""Internal method to prepare quantities.
Args:
fields (:class:`~pde.fields.base.FieldBase`):
The field containing the data to show
quantities (dict):
{ARG_PLOT_QUANTITIES}
scale (str, float, tuple of float):
{ARG_PLOT_SCALE}
Returns:
list of list of dict: a 2d arrangements of panels that define what
quantities are shown.
"""
if quantities is None:
# show each component by default
if isinstance(fields, FieldCollection):
quantities = []
for i, field in enumerate(fields):
title = field.label if field.label else f"Field {i + 1}"
quantity = {"title": title, "source": i, "scale": scale}
quantities.append(quantity)
else:
quantities = [{"title": "Concentration", "source": None}]
# make sure panels is a 2d array
if isinstance(quantities, dict):
quantities = [[quantities]]
elif isinstance(quantities[0], dict):
quantities = [quantities]
# set the scaling for quantities where it is not yet set
for row in quantities:
for col in row:
col.setdefault("scale", scale)
return quantities # type: ignore
def __del__(self):
try:
if hasattr(self, "fig") and self.fig:
import matplotlib.pyplot as plt
plt.close(self.fig)
except Exception:
pass # can occur when shutting down python interpreter
@fill_in_docstring
def _initialize(
self,
fields: FieldBase,
scale: ScaleData = "automatic",
fig=None,
title: str | None = None,
tight: bool = False,
):
"""Initialize the plot creating the figure and the axes.
Args:
fields (:class:`~pde.fields.base.FieldBase`):
Collection of fields
scale (str, float, tuple of float):
{ARG_PLOT_SCALE}
fig (:class:`matplotlib.figure.Figure):
Figure to be used for plotting. If `None`, a new figure is
created.
title (str):
Title of the plot.
tight (bool):
Whether to call :func:`matplotlib.pyplot.tight_layout`. This
affects the layout of all plot elements.
"""
import matplotlib.cm as cm
import matplotlib.pyplot as plt
num_rows = len(self.quantities)
num_cols = max(len(p) for p in self.quantities)
example_image = fields.get_image_data()
# set up the figure
if fig is None:
figsize = (4 * num_cols, 4 * num_rows)
self.fig, self.axes = plt.subplots(
num_rows, num_cols, sharey=True, squeeze=False, figsize=figsize
)
else:
self.fig = fig
self.axes = fig.axes
if title is None:
title = ""
self.sup_title = plt.suptitle(title)
# set up all images
empty = np.zeros_like(example_image["data"])
self.images: list[list[Any]] = []
for i, panel_row in enumerate(self.quantities):
img_row = []
for j, panel in enumerate(panel_row):
# determine scale of the panel
panel_scale = panel.get("scale", scale)
if panel_scale == "automatic":
vmin, vmax = None, None
elif panel_scale == "unity":
vmin, vmax = 0, 1
elif panel_scale == "symmetric":
vmin, vmax = -1, 1
else:
try:
vmin, vmax = panel_scale
except TypeError:
vmin, vmax = 0, panel_scale
# determine colormap
cmap = panel.get("cmap")
if cmap is None:
if vmin is None or vmax is None:
cmap = cm.viridis # type: ignore
elif np.isclose(-vmin, vmax):
cmap = cm.coolwarm # type: ignore
elif np.isclose(vmin, 0):
cmap = cm.gray # type: ignore
else:
cmap = cm.viridis # type: ignore
# initialize image
ax = self.axes[i, j]
img = ax.imshow(
empty,
vmin=vmin,
vmax=vmax,
cmap=cmap,
extent=example_image["extent"],
origin="lower",
interpolation="quadric",
)
_add_horizontal_colorbar(img, ax)
ax.set_title(panel.get("title", ""))
ax.set_xlabel(example_image["label_x"])
ax.set_ylabel(example_image["label_y"])
# store the default value alongside
img._default_vmin_vmax = vmin, vmax
img_row.append(img)
self.images.append(img_row)
if tight:
# adjust layout and leave some room for title
self.fig.tight_layout(rect=(0, 0.03, 1, 0.95))
def _update_data(self, fields: FieldBase, title: str | None = None) -> None:
"""Update the fields in the current plot.
Args:
fields (:class:`~pde.fields.base.FieldBase`):
The field or field collection of which the defined quantities
are shown.
title (str, optional):
The title of this view. If `None`, the current title is not
changed.
"""
if not isinstance(fields, FieldBase):
raise TypeError("`fields` must be of type FieldBase")
if title:
self.sup_title.set_text(title)
# iterate over all panels and update their content
for i, panel_row in enumerate(self.quantities):
for j, panel in enumerate(panel_row):
# obtain image data
field = extract_field(fields, panel.get("source"))
img_data = field.get_image_data()
# set the data in the correct panel
img = self.images[i][j]
img.set_data(img_data["data"])
vmin, vmax = img._default_vmin_vmax
if vmin is None:
vmin = img_data["data"].min()
if vmax is None:
vmax = img_data["data"].max()
img.set_clim(vmin, vmax)
def _show(self):
"""Show the updated plot."""
if self._ipython_out:
# seems to be in an ipython instance => update widget
from IPython.display import clear_output, display
with self._ipython_out:
clear_output(wait=True)
display(self.fig)
# add a small pause to allow the GUI to run it's event loop
time.sleep(0.01)
else:
# seems to be in a normal matplotlib window => update it
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# add a small pause to allow the GUI to run it's event loop
import matplotlib.pyplot as plt
plt.pause(0.01)
[docs]
def update(self, fields: FieldBase, title: str | None = None) -> None:
"""Update the plot with the given fields.
Args:
fields:
The field or field collection of which the defined quantities
are shown.
title (str, optional):
The title of this view. If `None`, the current title is not
changed.
"""
self._update_data(fields, title)
if self.show:
self._show()
[docs]
def savefig(self, path: str, **kwargs):
"""Save plot to file.
Args:
path (str):
The path to the file where the image is written. The file
extension determines the image format
**kwargs:
Additional arguments are forwarded to
:meth:`matplotlib.figure.Figure.savefig`.
"""
self.fig.savefig(path, **kwargs)
[docs]
def make_movie(
self, storage: StorageBase, filename: str, progress: bool = True
) -> None:
"""Make a movie from the data stored in storage.
Args:
storage (:class:`~pde.storage.base.StorageBase`):
The storage instance that contains all the data for the movie
filename (str):
The filename to which the movie is written. The extension
determines the format used.
progress (bool):
Flag determining whether the progress of making the movie is
shown.
"""
from ..visualization.movies import Movie
# create the iterator over the data
if progress:
data_iter = display_progress(storage.items(), total=len(storage))
else:
data_iter = storage.items()
with Movie(filename=filename) as movie:
# iterate over all time steps
for t, data in data_iter:
self.update(data, title=f"Time {t:g}")
movie.add_figure(self.fig)
[docs]
@plot_on_axes()
@fill_in_docstring
def plot_magnitudes(
storage: StorageBase, quantities=None, ax=None, **kwargs
) -> PlotReference:
r"""Plot spatially averaged quantities as a function of time.
For scalar fields, the default is to plot the average value while the averaged norm
is plotted for vector fields.
Args:
storage:
Instance of :class:`~pde.storage.base.StorageBase` that contains
the simulation data that will be plotted
quantities:
{ARG_PLOT_QUANTITIES}
{PLOT_ARGS}
\**kwargs:
All remaining parameters are forwarded to the `ax.plot` method
Returns:
:class:`~pde.tools.plotting.PlotReference`: The reference to the plot
"""
if quantities is None:
fields = storage[0]
if not isinstance(fields, FieldBase):
raise TypeError("`fields` must be of type FieldBase")
if fields.label:
label_base = fields.label
else:
label_base = kwargs.get("label", "Field")
if isinstance(fields, FieldCollection):
quantities = []
for i, field in enumerate(fields):
if field.label:
label = field.label
else:
label = label_base + f" {i + 1}"
quantities.append({"label": label, "source": i})
else:
quantities = [{"label": label_base, "source": None}]
_logger.debug("Quantities: %s", quantities)
# prepare data field
data = [
{"label": quantity.get("label", ""), "values": []} for quantity in quantities
]
# calculate the data
for fields in storage:
for i, quantity in enumerate(quantities):
source = quantity["source"]
if callable(source):
value = source(fields)
elif source is None:
value = fields.magnitude # type: ignore
else:
value = fields[source].magnitude # type: ignore
value *= quantity.get("scale", 1)
data[i]["values"].append(value)
# plot the data
lines = []
times = np.asarray(storage.times)
check_complex = True
for d in data:
kwargs["label"] = d["label"]
# warn if there is an imaginary part
if check_complex and np.any(np.iscomplex(d["values"])):
_logger.warning("Only the real part of the complex data is shown")
check_complex = False # only warn once
# actually plot the data
(l,) = ax.plot(times, np.real(d["values"]), **kwargs)
lines.append(l)
ax.set_xlabel("Time")
if len(data) == 1:
ax.set_ylabel(kwargs["label"])
if len(data) > 1:
ax.legend(loc="best")
return PlotReference(ax, lines)
def _plot_kymograph(
img_data: dict[str, Any],
ax,
colorbar: bool = True,
transpose: bool = False,
**kwargs,
) -> PlotReference:
r"""Plots a simple kymograph from given data.
Args:
img_data (dict):
Contains the kymograph data
ax (:class:`~matplotlib.axes.Axes`):
The axes to which the plot is added
colorbar (bool):
Whether to show a colorbar or not
transpose (bool):
Determines whether the transpose of the data should is plotted
\**kwargs:
Additional keyword arguments are passed to
:func:`matplotlib.pyplot.imshow`.
Returns:
:class:`~pde.tools.plotting.PlotReference`: The reference to the plot
"""
# transpose data if requested
if transpose:
# avoid changing the img_data that was passed into the function
extent = np.r_[img_data["extent_y"], img_data["extent_x"]]
img_data = {
"data": img_data["data"].T,
"label_x": img_data["label_y"],
"label_y": img_data["label_x"],
}
else:
extent = np.r_[img_data["extent_x"], img_data["extent_y"]]
# warn if there is an imaginary part
if np.any(np.iscomplex(img_data["data"])):
_logger.warning("Only the real part of the complex data is shown")
# create the actual plot
axes_image = ax.imshow(
img_data["data"].real, extent=extent, origin="lower", **kwargs
)
# adjust some settings
ax.set_xlabel(img_data["label_x"])
ax.set_ylabel(img_data["label_y"])
ax.set_aspect("auto")
if colorbar:
import matplotlib.pyplot as plt
plt.colorbar(axes_image, ax=ax)
# Note that we here do not use the method `add_scaled_colorbar` from the
# `pde.tools.plotting` module since this does not play well with axes with an
# explicit aspect ratio.
return PlotReference(ax, axes_image)
[docs]
@plot_on_axes()
def plot_kymograph(
storage: StorageBase,
field_index: int | None = None,
scalar: str = "auto",
extract: str = "auto",
colorbar: bool = True,
transpose: bool = False,
ax=None,
**kwargs,
) -> PlotReference:
r"""Plots a single kymograph from stored data.
The kymograph shows line data stacked along time. Consequently, the
resulting image shows space along the horizontal axis and time along the
vertical axis.
Args:
storage (:class:`~droplets.simulation.storage.StorageBase`):
The storage instance that contains all the data
field_index (int):
An index to choose a single field out of many in a collection
stored in `storage`. This option should not be used if only a single
field is stored in a collection.
scalar (str):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
extract (str):
The method used for extracting the line data. See the docstring
of the grid method `get_line_data` to find supported values.
colorbar (bool):
Whether to show a colorbar or not
transpose (bool):
Determines whether the transpose of the data should is plotted
{PLOT_ARGS}
\**kwargs:
Additional keyword arguments are passed to
:func:`matplotlib.pyplot.imshow`.
Returns:
:class:`~pde.tools.plotting.PlotReference`: The reference to the plot
"""
if len(storage) == 0:
raise RuntimeError("Storage is empty")
line_data_args: dict[str, Any] = {"scalar": scalar, "extract": extract}
if storage.has_collection:
# there are many fields in the storage
if field_index is None:
_logger.warning(
"Showing first field since argument `field_index` is missing. Use "
"`plot_kymographs` to show kymographs of multiple fields."
)
else:
line_data_args["index"] = field_index
elif field_index is not None and field_index != 0:
# there is only one field in the storage, but a field_index > 0 is given
_logger.warning("`field_index` should only be set for FieldCollections")
# prepare the image data for one kymograph
image = []
for _, field in storage.items():
img_data = field.get_line_data(**line_data_args)
image.append(img_data["data_y"])
img_data["data"] = np.array(image)
img_data["extent_y"] = (storage.times[0], storage.times[-1])
img_data["label_y"] = "Time"
return _plot_kymograph(
img_data, ax, colorbar=colorbar, transpose=transpose, **kwargs
)
[docs]
@plot_on_figure()
def plot_kymographs(
storage: StorageBase,
scalar: str = "auto",
extract: str = "auto",
colorbar: bool = True,
transpose: bool = False,
resize_fig: bool = True,
fig=None,
**kwargs,
) -> list[PlotReference]:
r"""Plots kymographs for all fields stored in `storage`
The kymograph shows line data stacked along time. Consequently, the
resulting image shows space along the horizontal axis and time along the
vertical axis.
Args:
storage (:class:`~droplets.simulation.storage.StorageBase`):
The storage instance that contains all the data
scalar (str):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
extract (str):
The method used for extracting the line data. See the docstring
of the grid method `get_line_data` to find supported values.
colorbar (bool):
Whether to show a colorbar or not
transpose (bool):
Determines whether the transpose of the data should is plotted
resize_fig (bool):
Whether to resize the figure to adjust to the number of panels
{PLOT_ARGS}
\**kwargs:
Additional keyword arguments are passed to the calls to
:func:`matplotlib.pyplot.imshow`.
Returns:
list of :class:`~pde.tools.plotting.PlotReference`: The references to
all plots
"""
if len(storage) == 0:
raise RuntimeError("Storage is empty")
line_data_args = {"scalar": scalar, "extract": extract}
if storage.has_collection:
num_fields = len(storage[0]) # type: ignore
else:
num_fields = 1
# prepare the image data for one kymograph
images = []
for _, field in storage.items():
if storage.has_collection:
image_lines = []
for f_id in range(num_fields):
img_data = field.get_line_data( # type: ignore
index=f_id,
**line_data_args,
)
image_lines.append(img_data["data_y"])
else:
img_data = field.get_line_data(**line_data_args)
image_lines = [img_data["data_y"]]
images.append(image_lines)
img_arr = np.array(images)
img_data["extent_y"] = (storage.times[0], storage.times[-1])
img_data["label_y"] = "Time"
# disable interactive plotting temporarily
# create a plot with all the panels
if resize_fig:
fig.set_size_inches((4 * num_fields, 3), forward=True)
axs = fig.subplots(1, num_fields, squeeze=False)
# iterate over all axes and plot the kymograph
refs = []
for i, ax in enumerate(axs[0]):
img_data["data"] = img_arr[:, i]
ref = _plot_kymograph(
img_data, ax, colorbar=colorbar, transpose=transpose, **kwargs
)
if storage.has_collection:
ax.set_title(storage._field[i].label) # type: ignore
refs.append(ref)
return refs
[docs]
def plot_interactive(
storage: StorageBase,
time_scaling: Literal["exact", "scaled"] = "exact",
viewer_args: dict[str, Any] | None = None,
**kwargs,
):
r"""plots stored data interactively using the `napari <https://napari.org>`_ viewer
Args:
storage (:class:`~droplets.simulation.storage.StorageBase`):
The storage instance that contains all the data
time_scaling (str):
Defines how the time axis is scaled. Possible options are "exact" (the
actual time points are used), or "scaled" (the axis is scaled so that it has
similar dimension to the spatial axes). Note that the spatial axes will
never be scaled.
viewer_args (dict):
Arguments passed to :class:`napari.viewer.Viewer` to affect the viewer.
**kwargs:
Extra arguments passed to the plotting function
"""
if not module_available("napari"):
raise ImportError("Require the `napari` module for interactive plotting")
if len(storage) < 1:
raise ValueError("Storage is empty")
grid = storage.grid
if grid is None:
raise RuntimeError("Storage did not contain information about the grid")
# collect data from all time points
timecourse: dict[str, list[np.ndarray]] = {}
for field in storage:
layer_data = field._get_napari_data(**kwargs)
if timecourse:
for key, field_data in layer_data.items():
timecourse[key].append(field_data["data"])
else:
for key, field_data in layer_data.items():
timecourse[key] = [field_data["data"]]
# replace the data in the layer_data
for key, field_data in layer_data.items():
field_data["data"] = np.array(timecourse[key])
if "scale" in field_data:
if time_scaling == "exact":
dt = np.diff(storage.times)[0] if len(storage) > 1 else 1
elif time_scaling == "scaled":
length_scale = grid.volume ** (1 / grid.dim)
dt = length_scale / len(storage)
else:
raise ValueError(f"Unknown time scaling `{time_scaling}`")
field_data["scale"] = np.r_[dt, field_data["scale"]]
if viewer_args is None:
viewer_args = {}
viewer_args.setdefault("axis_labels", ["Time"] + grid.axes)
# actually display the data using napari
with napari_viewer(grid, **viewer_args) as viewer:
napari_add_layers(viewer, layer_data)