Source code for pde.tools.mpi

"""
Auxiliary functions and variables for dealing with MPI multiprocessing


.. autosummary::
   :nosignatures:

   mpi_send
   mpi_recv
   mpi_allreduce

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

import os
import sys
from numbers import Number
from typing import TYPE_CHECKING, Union

import numba as nb
import numpy as np
from numba.core import types
from numba.extending import register_jitable

from .numba import jit

if TYPE_CHECKING:
    from numba_mpi import Operator  # @UnusedImport

# Initialize assuming that we run serial code if `numba_mpi` is not available
initialized: bool = False
"""bool: Flag determining whether mpi was initialized (and is available)"""

size: int = 1
"""int: Total process count"""

rank: int = 0
"""int: ID of the current process"""

# read state of the current MPI node
try:
    import numba_mpi

except ImportError:
    # package `numba_mpi` could not be loaded
    if int(os.environ.get("PMI_SIZE", "1")) > 1:
        # environment variable indicates that we are in a parallel program
        sys.exit(
            "WARNING: Detected multiprocessing run, but could not load `numba_mpi`"
        )

else:
    # we have access to MPI
    initialized = numba_mpi.initialized()
    size = numba_mpi.size()
    rank = numba_mpi.rank()

parallel_run: bool = size > 1
"""bool: Flag indicating whether the current run is using multiprocessing"""

is_main: bool = rank == 0
"""bool: Flag indicating whether the current process is the main process (with ID 0)"""


@jit
def mpi_send(data, dest: int, tag: int) -> None:
    """send data to another MPI node

    Args:
        data: The data being sent
        dest (int): The ID of the receiving node
        tag (int): A numeric tag identifying the message
    """
    status = numba_mpi.send(data, dest, tag)
    assert status == 0

@jit()
def mpi_recv(data, source, tag) -> None:
    """receive data from another MPI node

    Args:
        data: A buffer into which the received data is written
        source (int): The ID of the sending node
        tag (int): A numeric tag identifying the message
    """
    status = numba_mpi.recv(data, source, tag)
    assert status == 0

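A point-to-point sketch built on the two functions above (illustrative only;
assumes the script is launched with at least two processes, e.g., via
``mpiexec -n 2 python script.py``): rank 0 sends an array to rank 1, which
receives it into a preallocated buffer of matching shape and dtype:

    import numpy as np

    from pde.tools import mpi

    if mpi.rank == 0:
        mpi.mpi_send(np.arange(4.0), 1, 42)
    elif mpi.rank == 1:
        buffer = np.empty(4)  # receive buffer; shape and dtype must match
        mpi.mpi_recv(buffer, 0, 42)
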
@nb.generated_jit(nopython=True)
def mpi_allreduce(data, operator: Union[int, "Operator", None] = None):
    """combines data from all MPI nodes

    Note that complex datatypes and user-defined reduction functions are not
    properly supported.

    Args:
        data: Data being sent from this node to all others
        operator: The operator used to combine all data. Possible options are
            summarized in the IntEnum :class:`numba_mpi.Operator`.

    Returns:
        The accumulated data
    """
    # The following definition of `allreduce` is a workaround that allows using the
    # `numba_mpi.allreduce` function without jitting. It can be dropped once
    # `numba_mpi.allreduce` supports being called without jitting.
    if nb.config.DISABLE_JIT:

        def allreduce(sendobj, recvobj, operator=None):
            # with jitting disabled, calling `numba_mpi.allreduce` returns the
            # generated implementation, which must then be invoked explicitly
            if operator is None:
                impl = numba_mpi.allreduce(sendobj, recvobj)
                return impl(sendobj, recvobj)
            else:
                impl = numba_mpi.allreduce(sendobj, recvobj, operator)
                return impl(sendobj, recvobj, operator)

    else:

        @register_jitable
        def allreduce(sendobj, recvobj, operator=None):
            if operator is None:
                return numba_mpi.allreduce(sendobj, recvobj)
            else:
                return numba_mpi.allreduce(sendobj, recvobj, operator)

    if isinstance(data, (types.Number, Number)):

        def impl(data, operator=None):
            """reduce a single number across all cores"""
            sendobj = np.array([data])
            recvobj = np.empty((1,), sendobj.dtype)
            status = allreduce(sendobj, recvobj, operator)
            assert status == 0
            return recvobj[0]

    elif isinstance(data, (types.Array, np.ndarray)):

        def impl(data, operator=None):
            """reduce an array across all cores"""
            recvobj = np.empty(data.shape, data.dtype)
            status = allreduce(data, recvobj, operator)
            assert status == 0
            return recvobj

    else:
        raise TypeError(f"Unsupported type {data.__class__.__name__}")

    if nb.config.DISABLE_JIT:
        return impl(data, operator)
    else:
        return impl
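
A reduction sketch (illustrative only; uses the default operator, which in
``numba_mpi`` is the elementwise sum): every process contributes a local array
and all of them obtain the same combined result:

    import numpy as np

    from pde.tools import mpi

    local = np.full(3, mpi.rank, dtype=float)  # each rank contributes its ID
    total = mpi.mpi_allreduce(local)
    # hypothetical result for a 2-process run: array([1., 1., 1.])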