# Source code for pde.trackers.interactive

"""Special module for defining an interactive tracker that uses napari to display
fields.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import contextlib
import logging
import multiprocessing as mp
import platform
import queue
import signal
import time
from typing import Any

from ..fields.base import FieldBase
from ..tools.docstrings import fill_in_docstring
from ..tools.plotting import napari_add_layers
from .base import InfoDict, TrackerBase
from .interrupts import InterruptData


def napari_process(
    data_channel: mp.Queue,
    initial_data: dict[str, dict[str, Any]],
    t_initial: float | None = None,
    viewer_args: dict[str, Any] | None = None,
):
    """:mod:`multiprocessing.Process` running `napari <https://napari.org>`__

    Args:
        data_channel (:class:`multiprocessing.Queue`):
            queue instance to receive data to view. Items are tuples
            `(action, data)`, where `action` is either `"update"` (with `data`
            being a tuple `(layer_data, t)`) or `"close"`.
        initial_data (dict):
            Initial data to be shown by napari. The layers are named according to
            the keys in the dictionary. Each value is a dictionary describing one
            layer; it is passed to :func:`napari_add_layers`. Later updates replace
            a layer's data by the entry stored under the key `"data"`.
        t_initial (float):
            Initial time. If `None`, no time label is shown.
        viewer_args (dict):
            Additional arguments passed to the napari viewer
    """
    logger = logging.getLogger(__name__ + ".napari_process")

    try:
        import napari
        from napari.qt import thread_worker
    except ModuleNotFoundError:
        logger.error(
            "The `napari` python module could not be found. This module needs to be "
            "installed to use the interactive tracker."
        )
        return

    logger.info("Start napari process")

    # ignore keyboard interrupts in this process
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    if viewer_args is None:
        viewer_args = {}

    # create and initialize the viewer
    viewer = napari.Viewer(**viewer_args)
    napari_add_layers(viewer, initial_data)

    # add time if given
    if t_initial is not None:
        from qtpy.QtWidgets import QLabel

        label = QLabel()
        label.setText(f"Time: {t_initial}")
        viewer.window.add_dock_widget(label)
    else:
        label = None

    def check_signal(msg: tuple[str, Any] | None):
        """Helper function that processes messages yielded by the listener thread.

        This function is connected to the `yielded` signal of the worker, so napari
        calls it in the GUI thread. All modifications of the viewer and the Qt
        widgets therefore happen here, never in the listener thread itself.
        """
        if msg is None:
            return  # periodic yield without new data; nothing to do

        action, payload = msg
        if action == "close":
            viewer.close()
        elif action == "update":
            logger.debug("Update napari layer...")
            layer_data, t = payload
            if label is not None:
                label.setText(f"Time: {t}")
            for name, layer_info in layer_data.items():
                viewer.layers[name].data = layer_info["data"]
        else:
            raise RuntimeError(f"Unknown message from listener: {msg}")

    @thread_worker(connect={"yielded": check_signal})
    def update_listener():
        """Helper thread that listens to the data_channel.

        The thread only reads from the queue and yields messages; the viewer is
        modified exclusively by `check_signal` in the GUI thread.
        """
        logger.info("Start napari thread to receive data")

        # infinite loop waiting for events in the queue
        while True:
            time.sleep(0.02)  # poll the queue with 50 fps

            # get all items currently in the queue and keep only the last update
            update_data = None  # nothing to update yet
            while True:
                try:
                    action, data = data_channel.get(block=False)
                except queue.Empty:
                    break

                if action == "close":
                    logger.info("Forced closing of napari...")
                    yield ("close", None)  # signal to napari process to shut down
                    return  # stop listening; pending updates are irrelevant now
                elif action == "update":
                    update_data = data
                    # continue running until the queue is empty
                else:
                    logger.warning("Unexpected action: %s", action)

            # hand the most recent data (if any) to the GUI thread
            if update_data is not None:
                yield ("update", update_data)
            else:
                yield  # give napari the chance to interact with the worker

    # start worker thread that listens to the data_channel
    update_listener()

    # start napari
    napari.run()

    logger.info("Shutting down napari process")
class NapariViewer:
    """Allows viewing and updating data in a separate napari process."""

    def __init__(self, state: FieldBase, t_initial: float | None = None):
        """
        Args:
            state (:class:`pde.fields.base.FieldBase`):
                The initial state to be shown
            t_initial (float):
                The initial time. If `None`, no time will be shown.
        """
        self._logger = logging.getLogger(__name__)

        # pick a suitable multiprocessing context ("spawn" is used on macOS)
        if platform.system() == "Darwin":
            context: mp.context.BaseContext = mp.get_context("spawn")
        else:
            context = mp.get_context()

        # create process that runs napari
        self.data_channel = context.Queue()
        initial_data = state._get_napari_data()
        viewer_args = {
            "axis_labels": state.grid.axes,
            "ndisplay": 3 if state.grid.dim >= 3 else 2,
        }
        args = (self.data_channel, initial_data, t_initial, viewer_args)
        self.proc = context.Process(target=napari_process, args=args)  # type: ignore

        # start the process in the background
        try:
            self.proc.start()
        except RuntimeError:
            print()
            print("=" * 80)
            print(
                "It looks as if the main program did not use the multiprocessing "
                "safe-guard, which is necessary on some platforms. Please protect the "
                "main code of your program in the following way:"
            )
            print("")
            print("    if __name__ == '__main__':")
            print("        code ...")
            print("")
            print("The interactive Napari viewer could not be launched.")
            print("=" * 80)
            print()

            self._logger.exception("Could not launch napari process")

    def update(self, state: FieldBase, t: float):
        """Update the state in the napari viewer.

        Args:
            state (:class:`pde.fields.base.FieldBase`):
                The new state
            t (float):
                Current time
        """
        if self.proc.is_alive():
            try:
                data = (state._get_napari_data(), t)
                self.data_channel.put(("update", data), block=False)
            except queue.Full:
                pass  # could not write data
        else:
            # napari is not running; remove an item so the queue does not grow
            with contextlib.suppress(queue.Empty):
                self.data_channel.get(block=False)

    def close(self, force: bool = True):
        """Closes the napari process.

        Args:
            force (bool):
                Whether to force closing of the napari program. If this is `False`,
                this method blocks until the user closes napari manually.
        """
        if self.proc.is_alive():
            if force:
                # signal to napari process that it should be closed. Putting into an
                # already closed queue raises ValueError (e.g., on repeated calls)
                with contextlib.suppress(RuntimeError, ValueError):
                    self.data_channel.put(("close", None))
            # wait until napari has finished (forced or closed by the user)
            self.proc.join()

        # napari is not running (anymore), so nobody reads the remaining items from
        # the queue. Waiting for buffered data to be flushed into the pipe could then
        # block forever when the pipe is full, so we explicitly do not wait for it.
        self.data_channel.cancel_join_thread()
        self.data_channel.close()
class InteractivePlotTracker(TrackerBase):
    """Tracker showing the state interactively in napari.

    Note:
        The interactive tracker uses the python :mod:`multiprocessing` module to run
        `napari <http://napari.org/>`__ externally. The multiprocessing module has
        limitations on some platforms, which requires some care when writing your
        own programs. In particular, the main method needs to be safe-guarded so
        that the main module can be imported again after spawning a new process. An
        established pattern that works is to introduce a function `main` in your
        code, which you call using the following pattern

        .. code-block:: python

            def main():
                # here goes your main code

            if __name__ == "__main__":
                main()

        The last two lines ensure that the `main` function is only called when the
        module is run initially and not again when it is re-imported.
    """

    name = "interactive"

    @fill_in_docstring
    def __init__(
        self,
        interrupts: InterruptData = "0:01",
        *,
        close: bool = True,
        show_time: bool = False,
        interval=None,
    ):
        """
        Args:
            interrupts:
                {ARG_TRACKER_INTERRUPT}
            close (bool):
                Flag indicating whether the napari window is closed automatically at
                the end of the simulation. If `False`, the tracker blocks when
                `finalize` is called until the user closes napari manually.
            show_time (bool):
                Whether to display the simulation time in the napari window
        """
        super().__init__(interrupts=interrupts, interval=interval)
        # remember the display options for `initialize` and `finalize`
        self.show_time = show_time
        self.close = close

    def initialize(self, state: FieldBase, info: InfoDict | None = None) -> float:
        """Initialize the tracker with information about the simulation.

        Launches the napari viewer showing `state`.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation

        Returns:
            float: The first time the tracker needs to handle data
        """
        t_initial = None  # no time label unless it was requested
        if self.show_time:
            t_initial = info.get("t_start", 0) if info is not None else 0

        self._viewer = NapariViewer(state, t_initial=t_initial)
        return super().initialize(state, info=info)

    def handle(self, state: FieldBase, t: float) -> None:
        """Handle data supplied to this tracker.

        Sends the current state and time to the napari viewer.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        viewer = self._viewer
        viewer.update(state, t)

    def finalize(self, info: InfoDict | None = None) -> None:
        """Finalize the tracker, supplying additional information.

        Closes the napari viewer, either forcefully or by waiting for the user,
        depending on the `close` flag given at construction.

        Args:
            info (dict):
                Extra information from the simulation
        """
        force_close = self.close
        self._viewer.close(force=force_close)