
"""
Helper functions for just-in-time compilation with numba

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

import logging
import os
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

import numba as nb  # lgtm [py/import-and-import-from]
import numpy as np
from numba.extending import register_jitable
from numba.typed import Dict as NumbaDict

from .. import config
from ..tools.misc import decorator_arguments

try:
    # is_jitted has been added in numba 0.53 on 2021-03-11
    from numba.extending import is_jitted

except ImportError:
    # for earlier versions of numba, we need to define the function ourselves

    def is_jitted(function: Callable) -> bool:
        """determine whether a function has already been jitted"""
        try:
            from numba.core.dispatcher import Dispatcher
        except ImportError:
            # assume older numba module structure
            from numba.dispatcher import Dispatcher
        return isinstance(function, Dispatcher)
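
# A quick sanity check of `is_jitted` (illustrative only): plain Python functions are
# not jitted, while the dispatcher returned by `nb.njit` is:
#
#   >>> is_jitted(abs)
#   False
#   >>> is_jitted(nb.njit(lambda x: x))
#   True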


# numba version as a list of integers
NUMBA_VERSION = [int(v) for v in nb.__version__.split(".")[:2]]


class Counter:
    """helper class for implementing JIT_COUNT

    We cannot use a simple integer for this, since integers are immutable, so if one
    imports JIT_COUNT from this module it would always stay at the fixed value it had
    when it was first imported. The workaround would be to import the symbol every
    time the counter is read, but this is error-prone. Instead, we implement a thin
    wrapper class around an int, which only supports reading and incrementing the
    value. Since this object is now mutable it can be used easily. A disadvantage is
    that the object needs to be converted to int before it can be used in most
    expressions.
    """

    def __init__(self, value: int = 0):
        self._counter = value

    def __eq__(self, other):
        return self._counter == other

    def __int__(self):
        return self._counter

    def __iadd__(self, value):
        self._counter += value
        return self

    def increment(self):
        self._counter += 1

    def __repr__(self):
        return str(self._counter)
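
# A minimal usage sketch of `Counter` (illustrative, not part of the module itself):
#
#   >>> count = Counter()
#   >>> count.increment()
#   >>> count += 2
#   >>> int(count)
#   3
#   >>> count == 3
#   True
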
# global variable counting the number of compilations
JIT_COUNT = Counter()


TFunc = TypeVar("TFunc", bound="Callable")

def numba_environment() -> Dict[str, Any]:
    """return information about the numba setup used

    Returns:
        (dict) information about the numba setup
    """
    # determine whether Nvidia Cuda is available
    try:
        from numba import cuda

        cuda_available = cuda.is_available()
    except ImportError:
        cuda_available = False

    # determine whether AMD ROC is available
    try:
        from numba import roc

        roc_available = roc.is_available()
    except ImportError:
        roc_available = False

    # determine threading layer
    try:
        threading_layer = nb.threading_layer()
    except ValueError:
        # threading layer was not initialized, so compile a mock function
        @nb.jit("i8()", parallel=True)
        def f():
            s = 0
            for i in nb.prange(4):
                s += i
            return s

        f()
        try:
            threading_layer = nb.threading_layer()
        except ValueError:
            # cannot initialize threading
            threading_layer = None
    except AttributeError:
        # old numba version
        threading_layer = None

    return {
        "version": nb.__version__,
        "parallel": config["numba.parallel"],
        "fastmath": config["numba.fastmath"],
        "debug": config["numba.debug"],
        "using_svml": nb.config.USING_SVML,
        "threading_layer": threading_layer,
        "omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
        "mkl_num_threads": os.environ.get("MKL_NUM_THREADS"),
        "num_threads": nb.config.NUMBA_NUM_THREADS,
        "num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS,
        "cuda_available": cuda_available,
        "roc_available": roc_available,
    }
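
# A usage sketch: `numba_environment` simply collects diagnostic information, e.g.,
# for bug reports. The exact values depend on the local setup, but the version entry
# is guaranteed to match the installed numba:
#
#   >>> env = numba_environment()
#   >>> env["version"] == nb.__version__
#   True
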
def _get_jit_args(parallel: bool = False, **kwargs) -> Dict[str, Any]:
    """return arguments for :func:`nb.jit` with default values

    Args:
        parallel (bool): Allow parallel compilation of the function
        **kwargs: Additional arguments to `nb.jit`

    Returns:
        dict: Keyword arguments that can directly be used in :func:`nb.jit`
    """
    kwargs.setdefault("fastmath", config["numba.fastmath"])
    kwargs.setdefault("debug", config["numba.debug"])

    # make sure parallel numba is only enabled in restricted cases
    kwargs["parallel"] = parallel and config["numba.parallel"]
    return kwargs


if nb.config.DISABLE_JIT:
    # use work-around for https://github.com/numba/numba/issues/4759
    def flat_idx(arr, i):
        """helper function allowing indexing of scalars as if they were arrays"""
        if np.isscalar(arr):
            return arr
        else:
            return arr.flat[i]

else:
    # compiled version that specializes correctly
    @nb.generated_jit(nopython=True)
    def flat_idx(arr, i):
        """helper function allowing indexing of scalars as if they were arrays"""
        if isinstance(arr, (nb.types.Integer, nb.types.Float)):
            return lambda arr, i: arr
        else:
            return lambda arr, i: arr.flat[i]
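
# A usage sketch for `flat_idx`: scalars simply ignore the index, while arrays are
# indexed through their flat iterator (both branches behave the same way):
#
#   >>> flat_idx(3.0, 0)
#   3.0
#   >>> flat_idx(np.array([[1, 2], [3, 4]]), 3)
#   4
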
@decorator_arguments
def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TFunc:
    """apply nb.jit with predefined arguments

    Args:
        function: The function which is jitted
        signature: Signature of the function to compile
        parallel (bool): Allow parallel compilation of the function
        **kwargs: Additional arguments to `nb.jit`

    Returns:
        Function that will be compiled using numba
    """
    if is_jitted(function):
        return function

    # prepare the compilation arguments
    kwargs.setdefault("nopython", True)
    jit_kwargs = _get_jit_args(parallel=parallel, **kwargs)

    # log some details
    logger = logging.getLogger(__name__)
    name = function.__name__
    if kwargs["nopython"]:  # standard case
        logger.info("Compile `%s` with parallel=%s", name, jit_kwargs["parallel"])
    else:  # this might imply numba falls back to object mode
        logger.warning("Compile `%s` with nopython=False", name)

    # increase the compilation counter by one
    JIT_COUNT.increment()

    return nb.jit(signature, **jit_kwargs)(function)  # type: ignore
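
# A minimal usage sketch (the function `double` is hypothetical): `jit` behaves like
# `nb.jit`, but fills in the defaults from `config` and counts compilations:
#
#   >>> @jit
#   ... def double(arr):
#   ...     return 2 * arr
#   >>> double(np.arange(3))
#   array([0, 2, 4])
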
@decorator_arguments
def jit_allocate_out(
    func: Callable,
    parallel: bool = False,
    out_shape: Optional[Tuple[int, ...]] = None,
    num_args: int = 1,
    **kwargs,
) -> Callable:
    """Decorator that compiles a function with allocating an output array.

    This decorator compiles a function that takes the arguments `arr` and `out`. The
    point of this decorator is to make the `out` array optional by supplying an empty
    array of the same shape as `arr` if necessary. This is implemented efficiently by
    using :func:`numba.generated_jit`.

    Args:
        func: The function to be compiled
        parallel (bool):
            Determines whether the function is jitted with parallel=True.
        out_shape (tuple):
            Determines the shape of the `out` array. If omitted, the same shape as
            the input array is used.
        num_args (int, optional):
            Determines the number of input arguments of the function.
        **kwargs:
            Additional arguments used in :func:`numba.jit`

    Returns:
        The decorated function
    """
    # TODO: Remove `num_args` and use inspection on `func` instead
    if nb.config.DISABLE_JIT:
        # jitting is disabled => return generic python function
        if num_args == 1:

            def f_arg1_with_allocated_out(arr, out=None, args=None):
                """helper function allocating output array"""
                if out is None:
                    if out_shape is None:
                        out = np.empty_like(arr)
                    else:
                        out = np.empty(out_shape, dtype=arr.dtype)
                return func(arr, out, args=args)

            return f_arg1_with_allocated_out

        elif num_args == 2:

            def f_arg2_with_allocated_out(a, b, out=None, args=None):
                """helper function allocating output array"""
                if out is None:
                    assert a.shape == b.shape
                    if out_shape is None:
                        out = np.empty_like(a)
                    else:
                        out = np.empty(out_shape, dtype=a.dtype)
                return func(a, b, out, args=args)

            return f_arg2_with_allocated_out

        else:
            raise NotImplementedError("Only 1 or 2 arguments are supported")

    else:
        # jitting is enabled => return specific compiled functions
        jit_kwargs_outer = _get_jit_args(nopython=True, parallel=False, **kwargs)
        # we need to cast `parallel` to bool since np.bool is not supported by jit
        jit_kwargs_inner = _get_jit_args(parallel=bool(parallel), **kwargs)
        logging.getLogger(__name__).info(
            "Compile `%s` with %s", func.__name__, jit_kwargs_inner
        )

        if num_args == 1:

            @nb.generated_jit(**jit_kwargs_outer)
            @wraps(func)
            def wrapper(arr, out=None, args=None):
                """wrapper deciding whether the underlying function is called with or
                without `out`. This uses :func:`nb.generated_jit` to compile different
                versions of the same function
                """
                if isinstance(arr, nb.types.Number):
                    raise RuntimeError(
                        "Functions defined with an `out` keyword must not be called "
                        "with scalar quantities."
                    )
                if not isinstance(arr, nb.types.Array):
                    raise RuntimeError(
                        "Compiled functions need to be called with numpy arrays, not "
                        f"type {arr.__class__.__name__}"
                    )

                f_jit = register_jitable(**jit_kwargs_inner)(func)

                if isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
                    # function is called without `out`
                    if out_shape is None:
                        # we have to obtain the shape of `out` from `arr`
                        def f_with_allocated_out(arr, out=None, args=None):
                            """helper function allocating output array"""
                            return f_jit(arr, out=np.empty_like(arr), args=args)

                    else:
                        # the shape of `out` is given by `out_shape`
                        def f_with_allocated_out(arr, out=None, args=None):
                            """helper function allocating output array"""
                            out_arr = np.empty(out_shape, dtype=arr.dtype)
                            return f_jit(arr, out=out_arr, args=args)

                    return f_with_allocated_out

                else:
                    # function is called with `out` argument
                    if out_shape is None:
                        return f_jit
                    else:

                        def f_with_tested_out(arr, out=None, args=None):
                            """helper function checking the output array"""
                            assert out.shape == out_shape
                            return f_jit(arr, out, args=args)

                        return f_with_tested_out

        elif num_args == 2:

            @nb.generated_jit(**jit_kwargs_outer)
            @wraps(func)
            def wrapper(a, b, out=None, args=None):
                """wrapper deciding whether the underlying function is called with or
                without `out`. This uses :func:`nb.generated_jit` to compile different
                versions of the same function.
                """
                if isinstance(a, nb.types.Number):
                    # simple scalar call -> do not need to allocate anything
                    raise RuntimeError(
                        "Functions defined with an `out` keyword should not be "
                        "called with scalar quantities"
                    )

                elif isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
                    # function is called without `out`
                    f_jit = register_jitable(**jit_kwargs_inner)(func)

                    if out_shape is None:
                        # we have to obtain the shape of `out` from `a`
                        def f_with_allocated_out(a, b, out=None, args=None):
                            """helper function allocating output array"""
                            return f_jit(a, b, out=np.empty_like(a), args=args)

                    else:
                        # the shape of `out` is given by `out_shape`
                        def f_with_allocated_out(a, b, out=None, args=None):
                            """helper function allocating output array"""
                            out_arr = np.empty(out_shape, dtype=a.dtype)
                            return f_jit(a, b, out=out_arr, args=args)

                    return f_with_allocated_out

                else:
                    # function is called with `out` argument
                    return func

        else:
            raise NotImplementedError("Only 1 or 2 arguments are supported")

        # increase the compilation counter by one
        JIT_COUNT.increment()

        return wrapper  # type: ignore
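
# A usage sketch for `jit_allocate_out` (the function `scale` is hypothetical): the
# decorated function can be called with or without a pre-allocated `out` array:
#
#   >>> @jit_allocate_out
#   ... def scale(arr, out, args=None):
#   ...     for i in range(arr.size):
#   ...         out[i] = 2 * arr[i]
#   ...     return out
#   >>> scale(np.arange(4.0))  # `out` is allocated automatically
#   array([0., 2., 4., 6.])
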
if nb.config.DISABLE_JIT:
    # dummy function that creates a ctypes pointer
    def address_as_void_pointer(addr):
        """returns a void pointer from a given memory address

        Example:
            This can for instance be used together with `numba.carray`:

            >>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)

        Args:
            addr (int): The memory address

        Returns:
            :class:`ctypes.c_void_p`: Pointer to the memory address
        """
        import ctypes

        return ctypes.cast(addr, ctypes.c_void_p)

else:
    # actually useful function that creates a numba pointer
    @nb.extending.intrinsic
    def address_as_void_pointer(typingctx, src):
        """returns a void pointer from a given memory address

        Example:
            This can for instance be used together with `numba.carray`:

            >>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)

        Args:
            addr (int): The memory address

        Returns:
            :class:`numba.core.types.voidptr`: Pointer to the memory address
        """
        from numba.core import cgutils, types

        sig = types.voidptr(src)

        def codegen(cgctx, builder, sig, args):
            return builder.inttoptr(args[0], cgutils.voidptr_t)

        return sig, codegen
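
# A usage sketch combining `address_as_void_pointer` with `nb.carray` inside compiled
# code (the function `first_entry` is hypothetical; shape and dtype must be known to
# the caller, since a void pointer carries no type information):
#
#   >>> arr = np.arange(4.0)
#   >>> @jit
#   ... def first_entry(addr):
#   ...     view = nb.carray(address_as_void_pointer(addr), (4,), np.float64)
#   ...     return view[0]
#   >>> first_entry(arr.ctypes.data)
#   0.0
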
def make_array_constructor(arr: np.ndarray) -> Callable[[], np.ndarray]:
    """returns an array within a jitted function using basic information

    Args:
        arr (:class:`~numpy.ndarray`): The array that should be accessible within jit

    Warning:
        A reference to the array needs to be retained outside the numba code to
        prevent garbage collection from removing the array
    """
    data_addr = arr.__array_interface__["data"][0]
    strides = arr.__array_interface__["strides"]
    shape = arr.__array_interface__["shape"]
    dtype = arr.dtype

    @register_jitable
    def array_constructor() -> np.ndarray:
        """helper that reconstructs the array from the pointer and structural info"""
        data: np.ndarray = nb.carray(address_as_void_pointer(data_addr), shape, dtype)
        if strides is not None:
            data = np.lib.stride_tricks.as_strided(data, shape, strides)  # type: ignore
        return data

    return array_constructor  # type: ignore
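
# A usage sketch (the function `total` is hypothetical): the returned constructor can
# be called inside compiled code to recover the original array, as long as a
# reference to `arr` is kept alive outside the numba code:
#
#   >>> arr = np.linspace(0, 1, 5)
#   >>> construct = make_array_constructor(arr)
#   >>> @jit
#   ... def total():
#   ...     return construct().sum()
#   >>> total() == arr.sum()
#   True
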
@nb.generated_jit(nopython=True)
def convert_scalar(arr):
    """helper function that turns 0d-arrays into scalars

    This helps to avoid the bug discussed in
    https://github.com/numba/numba/issues/6000
    """
    if isinstance(arr, nb.types.Array) and arr.ndim == 0:
        return lambda arr: arr[()]
    else:
        return lambda arr: arr
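
# A usage sketch: 0d arrays are unwrapped into scalars, while everything else passes
# through unchanged:
#
#   >>> convert_scalar(np.array(3.0))
#   3.0
#   >>> convert_scalar(np.array([3.0]))
#   array([3.])
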
def numba_dict(data: Optional[Dict[str, Any]] = None) -> Optional[NumbaDict]:
    """converts a python dictionary to a numba typed dictionary"""
    if data is None:
        return None
    nb_dict = NumbaDict()
    for k, v in data.items():
        nb_dict[k] = v
    return nb_dict
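
# A usage sketch (the values must share a type that numba can infer):
#
#   >>> d = numba_dict({"a": 1, "b": 2})
#   >>> d["b"]
#   2
#   >>> numba_dict(None) is None
#   True
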
def get_common_numba_dtype(*args):
    r"""returns a numba numerical type in which all arrays can be represented

    Args:
        *args: All items to be tested

    Returns:
        numba.complex128 if any entry is complex, otherwise numba.double
    """
    from numba.core.types import npytypes, scalars

    for arg in args:
        if isinstance(arg, scalars.Complex):
            return nb.complex128
        elif isinstance(arg, npytypes.Array):
            if isinstance(arg.dtype, scalars.Complex):
                return nb.complex128
        else:
            raise NotImplementedError(f"Cannot handle type {arg.__class__}")
    return nb.double
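
# A usage sketch: the arguments are numba *types* (e.g., as received inside
# :func:`nb.generated_jit`), not the values themselves:
#
#   >>> get_common_numba_dtype(nb.typeof(np.zeros(2)), nb.typeof(1j))
#   complex128
#   >>> get_common_numba_dtype(nb.typeof(np.zeros(2)))
#   float64
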
if NUMBA_VERSION < [0, 45]:
    warnings.warn(
        "Your numba version is outdated. Please install at least version 0.45"
    )