"""
Special module for defining an interactive tracker that uses napari to display fields.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
import logging
import multiprocessing as mp
import platform
import queue
import signal
import time
from typing import Any, Dict, Optional
from ..fields.base import FieldBase
from ..tools.docstrings import fill_in_docstring
from ..tools.plotting import napari_add_layers
from .base import InfoDict, TrackerBase
from .interrupts import IntervalData
[docs]def napari_process(
data_channel: mp.Queue,
initial_data: Dict[str, Dict[str, Any]],
t_initial: Optional[float] = None,
viewer_args: Optional[Dict[str, Any]] = None,
):
""":mod:`multiprocessing.Process` running `napari <https://napari.org>`__
Args:
data_channel (:class:`multiprocessing.Queue`):
queue instance to receive data to view
initial_data (dict):
Initial data to be shown by napari. The layers are named according to
the keys in the dictionary. The associated value needs to be a tuple,
where the first item is a string indicating the type of the layer and
the second carries the associated data
t_initial (float):
Initial time
viewer_args (dict):
Additional arguments passed to the napari viewer
"""
logger = logging.getLogger(__name__ + ".napari_process")
try:
import napari
from napari.qt import thread_worker
except ModuleNotFoundError:
logger.error(
"The `napari` python module could not be found. This module needs to be "
"installed to use the interactive tracker."
)
return
logger.info("Start napari process")
# ignore keyboard interrupts in this process
signal.signal(signal.SIGINT, signal.SIG_IGN)
if viewer_args is None:
viewer_args = {}
# create and initialize the viewer
viewer = napari.Viewer(**viewer_args)
napari_add_layers(viewer, initial_data)
# add time if given
if t_initial is not None:
from qtpy.QtWidgets import QLabel
label = QLabel()
label.setText(f"Time: {t_initial}")
viewer.window.add_dock_widget(label)
else:
label = None
def check_signal(msg: Optional[str]):
"""helper function that processes messages by the listener thread"""
if msg is None:
return # do nothing
elif msg == "close":
viewer.close()
else:
raise RuntimeError(f"Unknown message from listener: {msg}")
@thread_worker(connect={"yielded": check_signal})
def update_listener():
"""helper thread that listens to the data_channel"""
logger.info("Start napari thread to receive data")
# infinite loop waiting for events in the queue
while True:
# get all items from the queue and display the last update
update_data = None # nothing to update yet
while True:
time.sleep(0.02) # read queue with 50 fps
try:
action, data = data_channel.get(block=False)
except queue.Empty:
break
if action == "close":
logger.info("Forced closing of napari...")
yield "close" # signal to napari process to shut down
break
elif action == "update":
update_data = data
# continue running until the queue is empty
else:
logger.warning(f"Unexpected action: {action}")
# update napari view when there is data
if update_data is not None:
logger.debug(f"Update napari layer...")
layer_data, t = update_data
if label is not None:
label.setText(f"Time: {t}")
for name, layer_data in layer_data.items():
viewer.layers[name].data = layer_data["data"]
yield
# start worker thread that listens to the data_channel
update_listener()
# start napari
napari.run()
logger.info("Shutting down napari process")
class NapariViewer:
    """allows viewing and updating data in a separate napari process

    Data is sent to the napari process through a :class:`multiprocessing.Queue`
    stored in `data_channel`. The process itself is accessible as `proc`.
    """

    def __init__(self, state: FieldBase, t_initial: Optional[float] = None):
        """
        Args:
            state (:class:`pde.fields.base.FieldBase`): The initial state to be shown
            t_initial (float): The initial time. If `None`, no time will be shown.
        """
        self._logger = logging.getLogger(__name__)

        # pick a suitable multiprocessing context (`spawn` is used on macOS)
        if platform.system() == "Darwin":
            context: mp.context.BaseContext = mp.get_context("spawn")
        else:
            context = mp.get_context()

        # create process that runs napari
        self.data_channel = context.Queue()
        initial_data = state._get_napari_data()
        viewer_args = {
            "axis_labels": state.grid.axes,
            "ndisplay": 3 if state.grid.dim >= 3 else 2,
        }
        args = (self.data_channel, initial_data, t_initial, viewer_args)
        self.proc = context.Process(target=napari_process, args=args)  # type: ignore

        # start the process in the background
        try:
            self.proc.start()
        except RuntimeError:
            print()
            print("=" * 80)
            print(
                "It looks as if the main program did not use the multiprocessing "
                "safe-guard, which is necessary on some platforms. Please protect the "
                "main code of your program in the following way:"
            )
            print("")
            print("    if __name__ == '__main__':")
            print("        code ...")
            print("")
            print("The interactive Napari viewer could not be launched.")
            print("=" * 80)
            print()
            self._logger.exception("Could not launch napari process")

    def update(self, state: FieldBase, t: float):
        """update the state in the napari viewer

        If the napari process is not running anymore, the new data is discarded and
        all messages still waiting in the queue are removed, so they do not pile up.

        Args:
            state (:class:`pde.fields.base.FieldBase`): The new state
            t (float): Current time
        """
        if self.proc.is_alive():
            try:
                data = (state._get_napari_data(), t)
                self.data_channel.put(("update", data), block=False)
            except queue.Full:
                pass  # could not write data; skip this update
        else:
            # nobody reads the queue anymore: discard all pending messages
            while True:
                try:
                    self.data_channel.get(block=False)
                except queue.Empty:
                    break

    def close(self, force: bool = True):
        """closes the napari process

        Args:
            force (bool):
                Whether to force closing of the napari program. If this is `False`, this
                method blocks until the user closes napari manually.
        """
        if force and self.proc.is_alive():
            # signal to napari process that it should be closed
            try:
                self.data_channel.put(("close", None))
            except (RuntimeError, ValueError):
                pass  # the queue cannot be used anymore (e.g., it was closed before)

        # wait for the napari process to finish; it keeps reading the queue until then
        if self.proc.is_alive():
            self.proc.join()

        # The napari process is gone, so buffered messages will never be consumed.
        # Do not wait for them to be flushed to the pipe, since that could block
        # forever (e.g., when large data was sent after the user closed napari).
        self.data_channel.cancel_join_thread()
        self.data_channel.close()
        self.data_channel.join_thread()
class InteractivePlotTracker(TrackerBase):
    """Tracker showing the state interactively in napari

    Note:
        The interactive tracker uses the python :mod:`multiprocessing` module to run
        `napari <http://napari.org/>`__ externally. The multiprocessing module has
        limitations on some platforms, which requires some care when writing your
        own programs. In particular, the main module needs to be safe-guarded, so it
        can be imported again after a new process has been spawned. An established
        pattern that works is to introduce a function `main` in your code, which you
        call as follows:

        .. code-block:: python

            def main():
                # here goes your main code

            if __name__ == "__main__":
                main()

        The last two lines ensure that the `main` function is only called when the
        module is run initially and not again when it is re-imported.
    """

    name = "interactive"

    @fill_in_docstring
    def __init__(
        self,
        interval: IntervalData = "0:01",
        close: bool = True,
        show_time: bool = False,
    ):
        """
        Args:
            interval:
                {ARG_TRACKER_INTERVAL}
            close (bool):
                Flag indicating whether the napari window is closed automatically at the
                end of the simulation. If `False`, the tracker blocks when `finalize` is
                called until the user closes napari manually.
            show_time (bool):
                Whether to display the current time in napari
        """
        super().__init__(interval=interval)
        # store the user choices; both are only evaluated later
        self.show_time = show_time
        self.close = close

    def initialize(self, state: FieldBase, info: Optional[InfoDict] = None) -> float:
        """initialize the tracker with information about the simulation

        This launches the napari viewer showing `state`.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example of the data that will be analyzed by the tracker
            info (dict):
                Extra information from the simulation. The entry `t_start` (default
                0) is used as the initial time when `show_time` is enabled.

        Returns:
            float: The first time the tracker needs to handle data
        """
        if not self.show_time:
            time_shown = None  # no time label in napari
        elif info is None:
            time_shown = 0
        else:
            time_shown = info.get("t_start", 0)
        self._viewer = NapariViewer(state, t_initial=time_shown)
        return super().initialize(state, info=info)

    def handle(self, state: FieldBase, t: float) -> None:
        """send the current state to the napari viewer

        Args:
            state (:class:`~pde.fields.FieldBase`):
                The current state of the simulation
            t (float):
                The associated time
        """
        viewer = self._viewer
        viewer.update(state, t)

    def finalize(self, info: Optional[InfoDict] = None) -> None:
        """finalize the tracker by closing the napari viewer

        Depending on the `close` flag given at construction, napari is either closed
        immediately or this call blocks until the user closes it.

        Args:
            info (dict):
                Extra information from the simulation
        """
        viewer = self._viewer
        viewer.close(force=self.close)