"""
Helper functions for just-in-time compilation with numba
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
import logging
import os
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
import numba as nb # lgtm [py/import-and-import-from]
import numpy as np
from numba.extending import register_jitable
from numba.typed import Dict as NumbaDict
from .. import config
from ..tools.misc import decorator_arguments
try:
    # `is_jitted` was added in numba 0.53 (released 2021-03-11)
from numba.extending import is_jitted
except ImportError:
    # for earlier versions of numba, we need to define the function ourselves
def is_jitted(function: Callable) -> bool:
"""determine whether a function has already been jitted"""
try:
from numba.core.dispatcher import Dispatcher
except ImportError:
# assume older numba module structure
from numba.dispatcher import Dispatcher
return isinstance(function, Dispatcher)
# numba version as a list of integers
NUMBA_VERSION = [int(v) for v in nb.__version__.split(".")[:2]]
# global variable counting the number of compilations
JIT_COUNT = 0
TFunc = TypeVar("TFunc", bound="Callable")
def numba_environment() -> Dict[str, Any]:
"""return information about the numba setup used
Returns:
(dict) information about the numba setup
"""
# determine whether Nvidia Cuda is available
try:
from numba import cuda
cuda_available = cuda.is_available()
except ImportError:
cuda_available = False
# determine whether AMD ROC is available
try:
from numba import roc
roc_available = roc.is_available()
except ImportError:
roc_available = False
# determine threading layer
try:
threading_layer = nb.threading_layer()
except ValueError:
# threading layer was not initialized, so compile a mock function
@nb.jit("i8()", parallel=True)
def f():
s = 0
for i in nb.prange(4):
s += i
return s
f()
try:
threading_layer = nb.threading_layer()
except ValueError: # cannot initialize threading
threading_layer = None
except AttributeError: # old numba version
threading_layer = None
return {
"version": nb.__version__,
"parallel": config["numba.parallel"],
"fastmath": config["numba.fastmath"],
"debug": config["numba.debug"],
"using_svml": nb.config.USING_SVML,
"threading_layer": threading_layer,
"omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
"mkl_num_threads": os.environ.get("MKL_NUM_THREADS"),
"num_threads": nb.config.NUMBA_NUM_THREADS,
"num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS,
"cuda_available": cuda_available,
"roc_available": roc_available,
}
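# Illustrative usage sketch (not part of the module API): the returned dict contains only
# plain values, so it can be printed or serialized directly for debugging, e.g.
#
#     import json
#     print(json.dumps(numba_environment(), indent=2, default=str))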
def _get_jit_args(parallel: bool = False, **kwargs) -> Dict[str, Any]:
"""return arguments for the :func:`nb.jit` with default values
Args:
parallel (bool): Allow parallel compilation of the function
**kwargs: Additional arguments to `nb.jit`
Returns:
dict: Keyword arguments that can directly be used in :func:`nb.jit`
"""
kwargs.setdefault("fastmath", config["numba.fastmath"])
kwargs.setdefault("debug", config["numba.debug"])
# make sure parallel numba is only enabled in restricted cases
kwargs["parallel"] = parallel and config["numba.parallel"]
return kwargs
if nb.config.DISABLE_JIT:
# use work-around for https://github.com/numba/numba/issues/4759
def flat_idx(arr, i):
"""helper function allowing indexing of scalars as if they arrays"""
if np.isscalar(arr):
return arr
else:
return arr.flat[i]
else:
# compiled version that specializes correctly
    @nb.generated_jit(nopython=True)
def flat_idx(arr, i):
"""helper function allowing indexing of scalars as if they arrays"""
if isinstance(arr, (nb.types.Integer, nb.types.Float)):
return lambda arr, i: arr
else:
return lambda arr, i: arr.flat[i]
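# Illustrative usage sketch (hypothetical caller): `flat_idx` lets the same kernel accept
# scalars and arrays alike, e.g.
#
#     @nb.njit
#     def first_element(value):
#         return flat_idx(value, 0)   # works for both 3.0 and np.arange(4.0)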
@decorator_arguments
def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TFunc:
"""apply nb.jit with predefined arguments
Args:
function: The function which is jitted
signature: Signature of the function to compile
parallel (bool): Allow parallel compilation of the function
**kwargs: Additional arguments to `nb.jit`
Returns:
Function that will be compiled using numba
"""
if is_jitted(function):
return function
# prepare the compilation arguments
kwargs.setdefault("nopython", True)
jit_kwargs = _get_jit_args(parallel=parallel, **kwargs)
# log some details
logger = logging.getLogger(__name__)
name = function.__name__
if kwargs["nopython"]: # standard case
logger.info("Compile `%s` with parallel=%s", name, jit_kwargs["parallel"])
else: # this might imply numba falls back to object mode
logger.warning("Compile `%s` with nopython=False", name)
# increase the compilation counter by one
global JIT_COUNT
JIT_COUNT += 1
return nb.jit(signature, **jit_kwargs)(function) # type: ignore
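# Illustrative usage sketch (not part of the module API): `jit` is applied like
# :func:`numba.njit`, either bare or with keyword arguments; the signature string below
# is only an example and is not required.
#
#     @jit
#     def double(arr):
#         return 2 * arr
#
#     @jit(signature="f8[:](f8[:])", parallel=True)
#     def double_inplace(arr):
#         for i in nb.prange(arr.size):
#             arr[i] *= 2
#         return arr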
@decorator_arguments
def jit_allocate_out(
func: Callable,
parallel: bool = False,
out_shape: Optional[Tuple[int, ...]] = None,
num_args: int = 1,
**kwargs,
) -> Callable:
"""Decorator that compiles a function with allocating an output array.
This decorator compiles a function that takes the arguments `arr` and `out`. The
point of this decorator is to make the `out` array optional by supplying an empty
array of the same shape as `arr` if necessary. This is implemented efficiently by
using :func:`numba.generated_jit`.
Args:
func:
The function to be compiled
parallel (bool):
Determines whether the function is jitted with parallel=True.
out_shape (tuple):
Determines the shape of the `out` array. If omitted, the same shape as the
input array is used.
num_args (int, optional):
Determines the number of input arguments of the function.
**kwargs:
Additional arguments used in :func:`numba.jit`
Returns:
The decorated function
"""
# TODO: Remove `num_args` and use inspection on `func` instead
if nb.config.DISABLE_JIT:
# jitting is disabled => return generic python function
if num_args == 1:
def f_arg1_with_allocated_out(arr, out=None, args=None):
"""helper function allocating output array"""
if out is None:
if out_shape is None:
out = np.empty_like(arr)
else:
out = np.empty(out_shape, dtype=arr.dtype)
return func(arr, out, args=args)
return f_arg1_with_allocated_out
elif num_args == 2:
def f_arg2_with_allocated_out(a, b, out=None, args=None):
"""helper function allocating output array"""
if out is None:
assert a.shape == b.shape
if out_shape is None:
out = np.empty_like(a)
else:
out = np.empty(out_shape, dtype=a.dtype)
return func(a, b, out, args=args)
return f_arg2_with_allocated_out
else:
raise NotImplementedError("Only 1 or 2 arguments are supported")
else:
# jitting is enabled => return specific compiled functions
jit_kwargs_outer = _get_jit_args(nopython=True, parallel=False, **kwargs)
# we need to cast `parallel` to bool since np.bool is not supported by jit
jit_kwargs_inner = _get_jit_args(parallel=bool(parallel), **kwargs)
logging.getLogger(__name__).info(
"Compile `%s` with %s", func.__name__, jit_kwargs_inner
)
if num_args == 1:
@nb.generated_jit(**jit_kwargs_outer)
@wraps(func)
def wrapper(arr, out=None, args=None):
"""wrapper deciding whether the underlying function is called
with or without `out`. This uses :func:`nb.generated_jit` to
compile different versions of the same function
"""
if isinstance(arr, nb.types.Number):
raise RuntimeError(
"Functions defined with an `out` keyword must not be called "
"with scalar quantities."
)
if not isinstance(arr, nb.types.Array):
raise RuntimeError(
"Compiled functions need to be called with numpy arrays, not "
f"type {arr.__class__.__name__}"
)
f_jit = register_jitable(**jit_kwargs_inner)(func)
if isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
# function is called without `out`
if out_shape is None:
# we have to obtain the shape of `out` from `arr`
def f_with_allocated_out(arr, out=None, args=None):
"""helper function allocating output array"""
return f_jit(arr, out=np.empty_like(arr), args=args)
else:
# the shape of `out` is given by `out_shape`
def f_with_allocated_out(arr, out=None, args=None):
"""helper function allocating output array"""
out_arr = np.empty(out_shape, dtype=arr.dtype)
return f_jit(arr, out=out_arr, args=args)
return f_with_allocated_out
else:
# function is called with `out` argument
if out_shape is None:
return f_jit
else:
def f_with_tested_out(arr, out=None, args=None):
"""helper function allocating output array"""
assert out.shape == out_shape
return f_jit(arr, out, args=args)
return f_with_tested_out
elif num_args == 2:
@nb.generated_jit(**jit_kwargs_outer)
@wraps(func)
def wrapper(a, b, out=None, args=None):
"""wrapper deciding whether the underlying function is called
with or without `out`. This uses nb.generated_jit to compile
different versions of the same function."""
if isinstance(a, nb.types.Number):
# simple scalar call -> do not need to allocate anything
raise RuntimeError(
"Functions defined with an `out` keyword should not be called "
"with scalar quantities"
)
elif isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
# function is called without `out`
f_jit = register_jitable(**jit_kwargs_inner)(func)
if out_shape is None:
# we have to obtain the shape of `out` from `a`
def f_with_allocated_out(a, b, out=None, args=None):
"""helper function allocating output array"""
return f_jit(a, b, out=np.empty_like(a), args=args)
else:
# the shape of `out` is given by `out_shape`
def f_with_allocated_out(a, b, out=None, args=None):
"""helper function allocating output array"""
out_arr = np.empty(out_shape, dtype=a.dtype)
return f_jit(a, b, out=out_arr, args=args)
return f_with_allocated_out
else:
# function is called with `out` argument
return func
else:
raise NotImplementedError("Only 1 or 2 arguments are supported")
return wrapper # type: ignore
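# Illustrative usage sketch (hypothetical kernel): the decorated function can be called
# with or without a preallocated `out` buffer, e.g.
#
#     @jit_allocate_out
#     def scale(arr, out=None, args=None):
#         for i in range(arr.size):
#             out[i] = 2 * arr[i]
#         return out
#
#     scale(np.arange(4.0))                   # `out` is allocated automatically
#     scale(np.arange(4.0), out=np.empty(4))  # reuse an existing buffer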
if nb.config.DISABLE_JIT:
# dummy function that creates a ctypes pointer
def address_as_void_pointer(addr):
"""returns a void pointer from a given memory address
Example:
This can for instance be used together with `numba.carray`:
>>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)
Args:
addr (int): The memory address
Returns:
:class:`ctypes.c_void_p`: Pointer to the memory address
"""
import ctypes
return ctypes.cast(addr, ctypes.c_void_p)
else:
# actually useful function that creates a numba pointer
@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
"""returns a void pointer from a given memory address
Example:
This can for instance be used together with `numba.carray`:
>>> addr = arr.ctypes.data
            >>> numba.carray(address_as_void_pointer(addr), arr.shape, arr.dtype)
Args:
addr (int): The memory address
Returns:
:class:`numba.core.types.voidptr`: Pointer to the memory address
"""
from numba.core import cgutils, types
sig = types.voidptr(src)
def codegen(cgctx, builder, sig, args):
return builder.inttoptr(args[0], cgutils.voidptr_t)
return sig, codegen
def make_array_constructor(arr: np.ndarray) -> Callable[[], np.ndarray]:
"""returns an array within a jitted function using basic information
Args:
arr (:class:`~numpy.ndarray`): The array that should be accessible within jit
Warning:
A reference to the array needs to be retained outside the numba code to prevent
garbage collection from removing the array
"""
data_addr = arr.__array_interface__["data"][0]
strides = arr.__array_interface__["strides"]
shape = arr.__array_interface__["shape"]
dtype = arr.dtype
@register_jitable
def array_constructor() -> np.ndarray:
"""helper that reconstructs the array from the pointer and structural info"""
data: np.ndarray = nb.carray(address_as_void_pointer(data_addr), shape, dtype)
if strides is not None:
data = np.lib.index_tricks.as_strided(data, shape, strides)
return data
return array_constructor # type: ignore
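# Illustrative usage sketch: the constructor can be baked into jitted code while the
# original array is kept alive outside, e.g.
#
#     data = np.arange(4.0)                   # keep a reference to avoid garbage collection
#     get_data = make_array_constructor(data)
#
#     @nb.njit
#     def total():
#         return get_data().sum()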
@nb.generated_jit(nopython=True)
def convert_scalar(arr):
"""helper function that turns 0d-arrays into scalars
This helps to avoid the bug discussed in
https://github.com/numba/numba/issues/6000
"""
if isinstance(arr, nb.types.Array) and arr.ndim == 0:
return lambda arr: arr[()]
else:
return lambda arr: arr
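# Illustrative usage sketch (with jitting enabled): `convert_scalar` unwraps 0d arrays
# while passing scalars through unchanged, e.g.
#
#     convert_scalar(np.array(3.0))   # -> 3.0
#     convert_scalar(3.0)             # -> 3.0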
def numba_dict(data: Optional[Dict[str, Any]] = None) -> Optional[NumbaDict]:
"""converts a python dictionary to a numba typed dictionary"""
if data is None:
return None
nb_dict = NumbaDict()
for k, v in data.items():
nb_dict[k] = v
return nb_dict
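# Illustrative usage sketch (hypothetical kernel): the typed dictionary can be passed
# into jitted functions, e.g. as an `args` parameter:
#
#     params = numba_dict({"a": 1.0, "b": 2.0})
#
#     @nb.njit
#     def param_sum(params):
#         return params["a"] + params["b"]
#
#     param_sum(params)   # -> 3.0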
def get_common_numba_dtype(*args):
r"""returns a numba numerical type in which all arrays can be represented
Args:
*args: All items to be tested
Returns: numba.complex128 if any entry is complex, otherwise numba.double
"""
from numba.core.types import npytypes, scalars
for arg in args:
if isinstance(arg, scalars.Complex):
return nb.complex128
elif isinstance(arg, npytypes.Array):
if isinstance(arg.dtype, scalars.Complex):
return nb.complex128
else:
raise NotImplementedError(f"Cannot handle type {arg.__class__}")
return nb.double
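# Illustrative usage sketch: the arguments are expected to be numba types, e.g. as
# obtained from :func:`numba.typeof`:
#
#     a = nb.typeof(np.zeros(3))                 # float64 array
#     b = nb.typeof(np.zeros(3, dtype=complex))  # complex128 array
#     get_common_numba_dtype(a, a)               # -> nb.double
#     get_common_numba_dtype(a, b)               # -> nb.complex128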
if NUMBA_VERSION < [0, 45]:
warnings.warn(
"Your numba version is outdated. Please install at least version 0.45"
)