"""
Helper functions for just-in-time compilation with numba

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

import logging
import os
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

import numba as nb  # lgtm [py/import-and-import-from]
import numpy as np
from numba.extending import register_jitable
from numba.typed import Dict as NumbaDict

from .. import config
from ..tools.misc import decorator_arguments

try:
    # is_jitted was added in numba 0.53 (released 2021-03-11)
    from numba.extending import is_jitted

except ImportError:
    # for earlier versions of numba, we need to define the function ourselves

    def is_jitted(function: Callable) -> bool:
        """determine whether a function has already been jitted"""
        try:
            from numba.core.dispatcher import Dispatcher
        except ImportError:
            # assume older numba module structure
            from numba.dispatcher import Dispatcher
        return isinstance(function, Dispatcher)


# numba version as a list of integers
NUMBA_VERSION = [int(v) for v in nb.__version__.split(".")[:2]]

# global variable counting the number of compilations
JIT_COUNT = 0


TFunc = TypeVar("TFunc", bound="Callable")


def numba_environment() -> Dict[str, Any]:
    """return information about the numba setup used

    Returns:
        (dict) information about the numba setup
    """
    # determine whether Nvidia Cuda is available
    try:
        from numba import cuda

        cuda_available = cuda.is_available()
    except ImportError:
        cuda_available = False

    # determine whether AMD ROC is available
    try:
        from numba import roc

        roc_available = roc.is_available()
    except ImportError:
        roc_available = False

    # determine threading layer
    try:
        threading_layer = nb.threading_layer()
    except ValueError:
        # threading layer was not initialized, so compile a mock function
        @nb.jit("i8()", parallel=True)
        def f():
            s = 0
            for i in nb.prange(4):
                s += i
            return s

        f()
        try:
            threading_layer = nb.threading_layer()
        except ValueError:
            # cannot initialize threading
            threading_layer = None
    except AttributeError:
        # old numba version
        threading_layer = None

    return {
        "version": nb.__version__,
        "parallel": config["numba.parallel"],
        "fastmath": config["numba.fastmath"],
        "debug": config["numba.debug"],
        "using_svml": nb.config.USING_SVML,
        "threading_layer": threading_layer,
        "omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
        "mkl_num_threads": os.environ.get("MKL_NUM_THREADS"),
        "num_threads": nb.config.NUMBA_NUM_THREADS,
        "num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS,
        "cuda_available": cuda_available,
        "roc_available": roc_available,
    }

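# Illustrative usage sketch (not part of the original module): the returned
# mapping can be inspected to diagnose the local numba setup, e.g.
#
#     env = numba_environment()
#     print(env["threading_layer"], env["num_threads"], env["parallel"])
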
def _get_jit_args(parallel: bool = False, **kwargs) -> Dict[str, Any]:
    """return arguments for the :func:`nb.jit` with default values

    Args:
        parallel (bool): Allow parallel compilation of the function
        **kwargs: Additional arguments to `nb.jit`

    Returns:
        dict: Keyword arguments that can directly be used in :func:`nb.jit`
    """
    kwargs.setdefault("fastmath", config["numba.fastmath"])
    kwargs.setdefault("debug", config["numba.debug"])

    # make sure parallel numba is only enabled in restricted cases
    kwargs["parallel"] = parallel and config["numba.parallel"]
    return kwargs


if nb.config.DISABLE_JIT:
    # use work-around for https://github.com/numba/numba/issues/4759
    def flat_idx(arr, i):
        """helper function allowing indexing of scalars as if they were arrays"""
        if np.isscalar(arr):
            return arr
        else:
            return arr.flat[i]

else:
    # compiled version that specializes correctly
    @nb.generated_jit(nopython=True)
    def flat_idx(arr, i):
        """helper function allowing indexing of scalars as if they were arrays"""
        if isinstance(arr, (nb.types.Integer, nb.types.Float)):
            return lambda arr, i: arr
        else:
            return lambda arr, i: arr.flat[i]

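# Illustrative sketch (not part of the original module): `flat_idx` lets code
# treat scalars and arrays uniformly, e.g.
#
#     flat_idx(3.0, 2)                          # -> 3.0 (scalars pass through)
#     flat_idx(np.arange(6).reshape(2, 3), 4)   # -> 4 (flat indexing of arrays)
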
@decorator_arguments
def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TFunc:
    """apply nb.jit with predefined arguments

    Args:
        function: The function which is jitted
        signature: Signature of the function to compile
        parallel (bool): Allow parallel compilation of the function
        **kwargs: Additional arguments to `nb.jit`

    Returns:
        Function that will be compiled using numba
    """
    if is_jitted(function):
        return function

    # prepare the compilation arguments
    kwargs.setdefault("nopython", True)
    jit_kwargs = _get_jit_args(parallel=parallel, **kwargs)

    # log some details
    logger = logging.getLogger(__name__)
    name = function.__name__
    if kwargs["nopython"]:
        # standard case
        logger.info("Compile `%s` with parallel=%s", name, jit_kwargs["parallel"])
    else:
        # this might imply numba falls back to object mode
        logger.warning("Compile `%s` with nopython=False", name)

    # increase the compilation counter by one
    global JIT_COUNT
    JIT_COUNT += 1

    return nb.jit(signature, **jit_kwargs)(function)  # type: ignore

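# Illustrative usage sketch (not part of the original module): `jit` is a
# drop-in replacement for `nb.jit` that applies the package-wide defaults from
# `config` and leaves already-compiled functions untouched, e.g.
#
#     @jit
#     def total(arr):
#         return arr.sum()
#
#     total(np.arange(4.0))  # compilation itself happens lazily on first call
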
@decorator_arguments
def jit_allocate_out(
    func: Callable,
    parallel: bool = False,
    out_shape: Optional[Tuple[int, ...]] = None,
    num_args: int = 1,
    **kwargs,
) -> Callable:
    """Decorator that compiles a function with allocating an output array.

    This decorator compiles a function that takes the arguments `arr` and `out`. The
    point of this decorator is to make the `out` array optional by supplying an empty
    array of the same shape as `arr` if necessary. This is implemented efficiently by
    using :func:`numba.generated_jit`.

    Args:
        func: The function to be compiled
        parallel (bool): Determines whether the function is jitted with parallel=True.
        out_shape (tuple): Determines the shape of the `out` array. If omitted, the
            same shape as the input array is used.
        num_args (int, optional): Determines the number of input arguments of the
            function.
        **kwargs: Additional arguments used in :func:`numba.jit`

    Returns:
        The decorated function
    """
    # TODO: Remove `num_args` and use inspection on `func` instead
    if nb.config.DISABLE_JIT:
        # jitting is disabled => return generic python function
        if num_args == 1:

            def f_arg1_with_allocated_out(arr, out=None, args=None):
                """helper function allocating output array"""
                if out is None:
                    if out_shape is None:
                        out = np.empty_like(arr)
                    else:
                        out = np.empty(out_shape, dtype=arr.dtype)
                return func(arr, out, args=args)

            return f_arg1_with_allocated_out

        elif num_args == 2:

            def f_arg2_with_allocated_out(a, b, out=None, args=None):
                """helper function allocating output array"""
                if out is None:
                    assert a.shape == b.shape
                    if out_shape is None:
                        out = np.empty_like(a)
                    else:
                        out = np.empty(out_shape, dtype=a.dtype)
                return func(a, b, out, args=args)

            return f_arg2_with_allocated_out

        else:
            raise NotImplementedError("Only 1 or 2 arguments are supported")

    else:
        # jitting is enabled => return specific compiled functions
        jit_kwargs_outer = _get_jit_args(nopython=True, parallel=False, **kwargs)
        # we need to cast `parallel` to bool since np.bool is not supported by jit
        jit_kwargs_inner = _get_jit_args(parallel=bool(parallel), **kwargs)
        logging.getLogger(__name__).info(
            "Compile `%s` with %s", func.__name__, jit_kwargs_inner
        )

        if num_args == 1:

            @nb.generated_jit(**jit_kwargs_outer)
            @wraps(func)
            def wrapper(arr, out=None, args=None):
                """wrapper deciding whether the underlying function is called with
                or without `out`. This uses :func:`nb.generated_jit` to compile
                different versions of the same function
                """
                if isinstance(arr, nb.types.Number):
                    raise RuntimeError(
                        "Functions defined with an `out` keyword must not be called "
                        "with scalar quantities."
                    )
                if not isinstance(arr, nb.types.Array):
                    raise RuntimeError(
                        "Compiled functions need to be called with numpy arrays, not "
                        f"type {arr.__class__.__name__}"
                    )

                f_jit = register_jitable(**jit_kwargs_inner)(func)

                if isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
                    # function is called without `out`
                    if out_shape is None:
                        # we have to obtain the shape of `out` from `arr`
                        def f_with_allocated_out(arr, out=None, args=None):
                            """helper function allocating output array"""
                            return f_jit(arr, out=np.empty_like(arr), args=args)

                    else:
                        # the shape of `out` is given by `out_shape`
                        def f_with_allocated_out(arr, out=None, args=None):
                            """helper function allocating output array"""
                            out_arr = np.empty(out_shape, dtype=arr.dtype)
                            return f_jit(arr, out=out_arr, args=args)

                    return f_with_allocated_out

                else:
                    # function is called with `out` argument
                    if out_shape is None:
                        return f_jit
                    else:

                        def f_with_tested_out(arr, out=None, args=None):
                            """helper function testing the shape of the output array"""
                            assert out.shape == out_shape
                            return f_jit(arr, out, args=args)

                        return f_with_tested_out

        elif num_args == 2:

            @nb.generated_jit(**jit_kwargs_outer)
            @wraps(func)
            def wrapper(a, b, out=None, args=None):
                """wrapper deciding whether the underlying function is called with
                or without `out`. This uses :func:`nb.generated_jit` to compile
                different versions of the same function.
                """
                if isinstance(a, nb.types.Number):
                    # simple scalar call -> do not need to allocate anything
                    raise RuntimeError(
                        "Functions defined with an `out` keyword should not be called "
                        "with scalar quantities"
                    )

                elif isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
                    # function is called without `out`
                    f_jit = register_jitable(**jit_kwargs_inner)(func)

                    if out_shape is None:
                        # we have to obtain the shape of `out` from `a`
                        def f_with_allocated_out(a, b, out=None, args=None):
                            """helper function allocating output array"""
                            return f_jit(a, b, out=np.empty_like(a), args=args)

                    else:
                        # the shape of `out` is given by `out_shape`
                        def f_with_allocated_out(a, b, out=None, args=None):
                            """helper function allocating output array"""
                            out_arr = np.empty(out_shape, dtype=a.dtype)
                            return f_jit(a, b, out=out_arr, args=args)

                    return f_with_allocated_out

                else:
                    # function is called with `out` argument
                    return func

        else:
            raise NotImplementedError("Only 1 or 2 arguments are supported")

        return wrapper  # type: ignore

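# Illustrative usage sketch (not part of the original module): the decorated
# function can be called with or without a pre-allocated output buffer, e.g.
#
#     @jit_allocate_out
#     def double(arr, out=None, args=None):
#         out[:] = 2 * arr
#         return out
#
#     double(np.arange(3.0))                   # `out` is allocated automatically
#     double(np.arange(3.0), out=np.empty(3))  # the supplied buffer is reused
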
if nb.config.DISABLE_JIT:
    # dummy function that creates a ctypes pointer
    def address_as_void_pointer(addr):
        """returns a void pointer from a given memory address

        Example:
            This can for instance be used together with `numba.carray`:

            >>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)

        Args:
            addr (int): The memory address

        Returns:
            :class:`ctypes.c_void_p`: Pointer to the memory address
        """
        import ctypes

        return ctypes.cast(addr, ctypes.c_void_p)

else:
    # actually useful function that creates a numba pointer
    @nb.extending.intrinsic
    def address_as_void_pointer(typingctx, src):
        """returns a void pointer from a given memory address

        Example:
            This can for instance be used together with `numba.carray`:

            >>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)

        Args:
            addr (int): The memory address

        Returns:
            :class:`numba.core.types.voidptr`: Pointer to the memory address
        """
        from numba.core import cgutils, types

        sig = types.voidptr(src)

        def codegen(cgctx, builder, sig, args):
            return builder.inttoptr(args[0], cgutils.voidptr_t)

        return sig, codegen

def make_array_constructor(arr: np.ndarray) -> Callable[[], np.ndarray]:
    """returns an array within a jitted function using basic information

    Args:
        arr (:class:`~numpy.ndarray`): The array that should be accessible within jit

    Warning:
        A reference to the array needs to be retained outside the numba code to
        prevent garbage collection from removing the array
    """
    data_addr = arr.__array_interface__["data"][0]
    strides = arr.__array_interface__["strides"]
    shape = arr.__array_interface__["shape"]
    dtype = arr.dtype

    @register_jitable
    def array_constructor() -> np.ndarray:
        """helper that reconstructs the array from the pointer and structural info"""
        data: np.ndarray = nb.carray(address_as_void_pointer(data_addr), shape, dtype)
        if strides is not None:
            data = np.lib.stride_tricks.as_strided(data, shape, strides)
        return data

    return array_constructor  # type: ignore

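# Illustrative usage sketch (not part of the original module): the constructor
# gives compiled code access to an array defined outside of it; per the warning
# above, the original array must be kept alive while the function is in use, e.g.
#
#     data = np.arange(4.0)  # retain this reference to prevent garbage collection
#     get_data = make_array_constructor(data)
#
#     @jit
#     def first_item():
#         return get_data()[0]
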
@nb.generated_jit(nopython=True)
def convert_scalar(arr):
    """helper function that turns 0d-arrays into scalars

    This helps to avoid the bug discussed in
    https://github.com/numba/numba/issues/6000
    """
    if isinstance(arr, nb.types.Array) and arr.ndim == 0:
        return lambda arr: arr[()]
    else:
        return lambda arr: arr

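# Illustrative sketch (not part of the original module): `convert_scalar`
# normalizes 0d-arrays to plain scalars while leaving other arrays alone, e.g.
#
#     convert_scalar(np.array(2.0))  # -> 2.0 instead of a 0d array
#     convert_scalar(np.arange(3))   # -> the array, unchanged
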
def numba_dict(data: Optional[Dict[str, Any]] = None) -> Optional[NumbaDict]:
    """converts a python dictionary to a numba typed dictionary"""
    if data is None:
        return None
    nb_dict = NumbaDict()
    for k, v in data.items():
        nb_dict[k] = v
    return nb_dict

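# Illustrative usage sketch (not part of the original module): the resulting
# typed dictionary can be passed into nopython-mode functions, unlike a plain
# python dict; note that numba requires homogeneous value types, e.g.
#
#     params = numba_dict({"diffusivity": 0.1, "rate": 2.0})
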
def get_common_numba_dtype(*args):
    r"""returns a numba numerical type in which all arrays can be represented

    Args:
        *args: All items to be tested

    Returns:
        numba.complex128 if any entry is complex, otherwise numba.double
    """
    from numba.core.types import npytypes, scalars

    for arg in args:
        if isinstance(arg, scalars.Complex):
            return nb.complex128
        elif isinstance(arg, npytypes.Array):
            if isinstance(arg.dtype, scalars.Complex):
                return nb.complex128
        else:
            raise NotImplementedError(f"Cannot handle type {arg.__class__}")
    return nb.double

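# Illustrative sketch (not part of the original module): given numba types, the
# helper picks the common numerical dtype, e.g.
#
#     get_common_numba_dtype(nb.typeof(np.zeros(3)), nb.typeof(1 + 2j))
#     # -> nb.complex128, since one of the arguments is complex
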
if NUMBA_VERSION < [0, 45]:
    warnings.warn(
        "Your numba version is outdated. Please install at least version 0.45"
    )