"""Defines the :mod:`numba` backend class.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import functools
import warnings
from typing import TYPE_CHECKING, Any
import numba as nb
import numpy as np
from numba.extending import is_jitted, register_jitable
from numba.extending import overload as nb_overload
from ...fields import DataFieldBase, VectorField
from ...grids import DimensionError, DomainError, GridBase
from ...grids.boundaries.axes import BoundariesBase, BoundariesList, BoundariesSetter
from ...grids.boundaries.local import BCBase, UserBC
from ...tools.cache import cached_method
from ...tools.config import is_hpc_environment
from ...tools.typing import OperatorInfo
from ..numpy.backend import NumpyBackend
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from numpy.typing import DTypeLike
from ...grids.boundaries.axis import BoundaryAxisBase
from ...pdes import PDEBase
from ...solvers.base import SolverBase
from ...tools.expressions import ExpressionBase
from ...tools.typing import (
BinaryOperatorImplType,
FloatingArray,
GhostCellSetter,
InexactArray,
Number,
NumberOrArray,
NumericArray,
OperatorType,
StepperType,
TField,
TFunc,
)
[docs]
class NumbaBackend(NumpyBackend):
"""Defines :mod:`numba` backend."""
implementation = "numba"
[docs]
def compile_function(self, func: TFunc, **kwargs) -> TFunc:
    """Compile a user-supplied function for this backend.

    Args:
        func (callable):
            The function that should be compiled
        **kwargs:
            Extra options handed to :func:`pde.backends.numba.utils.jit`

    Returns:
        The result of applying the configured `jit` decorator to `func`
    """
    from .utils import jit

    # build the decorator first, then apply it to the user function
    decorator = jit(backend=self, **kwargs)
    return decorator(func)
[docs]
def use_multithreading(self) -> bool:
    """Decide whether numba-compiled code should use multithreading.

    The decision is based on the configuration parameter `multithreading` of this
    backend, which is expected to hold one of the following values:

    - 'always': Multithreading is always enabled.
    - 'never': Multithreading is never enabled.
    - 'only_local': Multithreading is enabled unless
      :func:`is_hpc_environment` reports an HPC environment.

    Returns:
        bool: True if multithreading should be enabled, False otherwise.

    Raises:
        ValueError: If the configured value is none of the three options above.
    """
    setting = self._config_parameter("multithreading")

    # the two unconditional choices map directly onto a boolean
    if setting == "always" or setting == "never":
        return setting == "always"

    # the remaining valid choice depends on the runtime environment
    if setting == "only_local":
        return not is_hpc_environment()

    msg = (
        "Parameter `backend.numba.multithreading` must be in {'always', 'never', "
        f"'only_local'}}, not `{setting}`"
    )
    raise ValueError(msg)
[docs]
def get_registered_operators(self, grid_id: GridBase | type[GridBase]) -> set[str]:
    """Return the names of all operators available for a grid on this backend.

    The operators reported by the parent backend are extended by the names of the
    first and second derivatives along each axis listed in `grid_id.axes` (if that
    attribute exists).

    Args:
        grid_id (:class:`~pde.grid.base.GridBase` or its type):
            Grid for which the operators need to be returned

    Returns:
        set of str: the operator names
    """
    operators = super().get_registered_operators(grid_id)

    # add operators calculating derivatives along each coordinate axis
    axes = getattr(grid_id, "axes", [])
    operators |= {
        name
        for ax in axes
        for name in (
            f"d_d{ax}",
            f"d_d{ax}_forward",
            f"d_d{ax}_backward",
            f"d2_d{ax}2",
        )
    }
    return operators
[docs]
def get_operator_info(
    self, grid: GridBase, operator: str | OperatorInfo
) -> OperatorInfo:
    """Return the operator defined for this backend.

    Operators known to the parent backend are returned directly. If the parent
    does not know the operator, the patterns ``d_d<axis>``, ``d_d<axis>_central``,
    ``d_d<axis>_forward``, ``d_d<axis>_backward``, and ``d2_d<axis>2`` are
    resolved to first and second derivatives along the named axis.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the operator is needed
        operator (str or :class:`~pde.grids.base.OperatorInfo`):
            Identifier for the operator. Some examples are 'laplace', 'gradient', or
            'divergence'. The registered operators for this grid can be obtained
            from the :attr:`~pde.grids.base.GridBase.operators` attribute. An
            :class:`~pde.grids.base.OperatorInfo` instance is returned unchanged.

    Returns:
        :class:`~pde.grids.base.OperatorInfo`: information for the operator

    Raises:
        ValueError: If the operator follows a derivative pattern but names an axis
            that is not part of `grid.axes`.
        NotImplementedError: If the operator is not defined.
    """
    if isinstance(operator, OperatorInfo):
        return operator
    assert isinstance(operator, str)

    def get_axis_index(axis_name: str) -> int:
        """Return the index of `axis_name`, raising a descriptive error if missing."""
        try:
            return grid.axes.index(axis_name)
        except ValueError as err:
            # keep the exception type of `list.index`, but explain what went wrong
            msg = (
                f"Operator '{operator}' refers to unknown axis '{axis_name}'. "
                f"Axes of {grid.__class__.__name__}: {list(grid.axes)}"
            )
            raise ValueError(msg) from err

    try:
        # try the default method for determining operators
        return super().get_operator_info(grid, operator)
    except NotImplementedError:
        # deal with some special patterns that are often used
        if operator.startswith("d_d"):
            # create a special operator that takes a first derivative along one axis
            from .operators.common import make_derivative

            # determine axis to which operator is applied (and the method to use)
            axis_name = operator[len("d_d") :]
            method = "central"  # used when no explicit suffix is given
            for direction in ("central", "forward", "backward"):
                suffix = "_" + direction
                if axis_name.endswith(suffix):
                    method = direction
                    axis_name = axis_name[: -len(suffix)]
                    break

            factory = functools.partial(
                make_derivative,
                axis=get_axis_index(axis_name),
                method=method,  # type: ignore
            )
            return OperatorInfo(factory, rank_in=0, rank_out=0, name=operator)

        if operator.startswith("d2_d") and operator.endswith("2"):
            # create a special operator taking a second derivative along one axis
            from .operators.common import make_derivative2

            axis_id = get_axis_index(operator[len("d2_d") : -1])
            factory = functools.partial(make_derivative2, axis=axis_id)
            return OperatorInfo(factory, rank_in=0, rank_out=0, name=operator)

        # throw an informative error since operator was not found
        op_list = ", ".join(sorted(self.get_registered_operators(grid)))
        msg = (
            f"'{operator}' is not defined for {grid.__class__.__name__} on backend "
            f"{self.__class__.__name__}. Operators can be added using the "
            f"`register_operator` method. Defined operators: {op_list}."
        )
        raise NotImplementedError(msg)
def _make_local_ghost_cell_setter(self, bc: BCBase) -> GhostCellSetter:
    """Return function that sets the ghost cells for a particular side of an axis.

    A specialized setter is created for grids with one, two, or three axes and for
    each axis the boundary can belong to. Every setter visits all cells of the
    boundary, evaluates the virtual point with the evaluator created from `bc`, and
    writes the result into the ghost-cell layer of the full data array. If
    `bc.normal` is set, only the component with index `bc.axis` is written.

    Args:
        bc (:class:`~pde.grids.boundaries.local.BCBase`):
            Defines the boundary conditions for a particular side, for which the
            setter should be defined.

    Returns:
        Callable with signature :code:`(data_full: NumericArray, args=None)`, which
        sets the ghost cells of the full data, potentially using additional
        information in `args` (e.g., the time `t` during solving a PDE)

    Raises:
        NotImplementedError: If the grid has more than three axes.
    """
    from ._boundaries import make_virtual_point_evaluator

    # copy attributes into locals so the jitted closures below can capture them
    normal = bc.normal
    axis = bc.axis

    # get information of the virtual points (ghost cells)
    # index of the ghost cell in the full array: last entry for the upper side,
    # first entry for the lower side
    vp_idx = bc.grid.shape[bc.axis] + 1 if bc.upper else 0
    # index (in the valid data) passed to the evaluator along the boundary axis
    np_idx = bc._get_value_cell_index(with_ghost_cells=False)
    vp_value = make_virtual_point_evaluator(bc)

    if bc.grid.num_axes == 1:  # 1d grid

        @register_jitable
        def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
            """Set the single ghost cell of this side of a 1d grid."""
            # strip one ghost cell on either end to obtain the valid data
            data_valid = data_full[..., 1:-1]
            val = vp_value(data_valid, (np_idx,), args=args)
            if normal:
                data_full[..., axis, vp_idx] = val
            else:
                data_full[..., vp_idx] = val

    elif bc.grid.num_axes == 2:  # 2d grid
        if bc.axis == 0:
            num_y = bc.grid.shape[1]

            @register_jitable
            def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
                """Set ghost cells on a boundary perpendicular to axis 0 (2d)."""
                data_valid = data_full[..., 1:-1, 1:-1]
                for j in range(num_y):
                    val = vp_value(data_valid, (np_idx, j), args=args)
                    # `j + 1` converts the valid index to an index of the full array
                    if normal:
                        data_full[..., axis, vp_idx, j + 1] = val
                    else:
                        data_full[..., vp_idx, j + 1] = val

        elif bc.axis == 1:
            num_x = bc.grid.shape[0]

            @register_jitable
            def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
                """Set ghost cells on a boundary perpendicular to axis 1 (2d)."""
                data_valid = data_full[..., 1:-1, 1:-1]
                for i in range(num_x):
                    val = vp_value(data_valid, (i, np_idx), args=args)
                    if normal:
                        data_full[..., axis, i + 1, vp_idx] = val
                    else:
                        data_full[..., i + 1, vp_idx] = val

    elif bc.grid.num_axes == 3:  # 3d grid
        if bc.axis == 0:
            num_y, num_z = bc.grid.shape[1:]

            @register_jitable
            def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
                """Set ghost cells on a boundary perpendicular to axis 0 (3d)."""
                data_valid = data_full[..., 1:-1, 1:-1, 1:-1]
                for j in range(num_y):
                    for k in range(num_z):
                        val = vp_value(data_valid, (np_idx, j, k), args=args)
                        if normal:
                            data_full[..., axis, vp_idx, j + 1, k + 1] = val
                        else:
                            data_full[..., vp_idx, j + 1, k + 1] = val

        elif bc.axis == 1:
            num_x, num_z = bc.grid.shape[0], bc.grid.shape[2]

            @register_jitable
            def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
                """Set ghost cells on a boundary perpendicular to axis 1 (3d)."""
                data_valid = data_full[..., 1:-1, 1:-1, 1:-1]
                for i in range(num_x):
                    for k in range(num_z):
                        val = vp_value(data_valid, (i, np_idx, k), args=args)
                        if normal:
                            data_full[..., axis, i + 1, vp_idx, k + 1] = val
                        else:
                            data_full[..., i + 1, vp_idx, k + 1] = val

        elif bc.axis == 2:
            num_x, num_y = bc.grid.shape[:2]

            @register_jitable
            def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
                """Set ghost cells on a boundary perpendicular to axis 2 (3d)."""
                data_valid = data_full[..., 1:-1, 1:-1, 1:-1]
                for i in range(num_x):
                    for j in range(num_y):
                        val = vp_value(data_valid, (i, j, np_idx), args=args)
                        if normal:
                            data_full[..., axis, i + 1, j + 1, vp_idx] = val
                        else:
                            data_full[..., i + 1, j + 1, vp_idx] = val

    else:
        msg = "Too many axes"
        raise NotImplementedError(msg)

    if isinstance(bc, UserBC):
        # the (pretty uncommon) UserBC needs a special check, which we add here

        @register_jitable
        def ghost_cell_setter_wrapped(data_full: NumericArray, args=None) -> None:
            """Call the ghost cell setter only if `args` carries boundary data."""
            if args is None:
                return  # no-op when no specific arguments are given
            if "virtual_point" in args or "value" in args or "derivative" in args:
                # ghost cells will only be set if any of the above keys are supplied
                ghost_cell_setter(data_full, args=args)
            # else: no-op for the default case where BCs are not set by user

        return ghost_cell_setter_wrapped  # type: ignore

    # the standard case just uses the ghost_cell_setter as defined above
    return ghost_cell_setter  # type: ignore
def _make_axis_ghost_cell_setter(
    self, bc_axis: BoundaryAxisBase
) -> GhostCellSetter:
    """Return function that sets the ghost cells on both sides of one axis.

    Args:
        bc_axis (:class:`~pde.grids.boundaries.axis.BoundaryAxisBase`):
            Boundary conditions of the axis; its `low` and `high` attributes
            provide the conditions for the two sides.

    Returns:
        Callable with signature :code:`(data_full: NumericArray, args=None)`, which
        sets the ghost cells of the full data, potentially using additional
        information in `args` (e.g., the time `t` during solving a PDE)
    """
    # build one setter for each side of the axis
    set_lower_side = self._make_local_ghost_cell_setter(bc_axis.low)
    set_upper_side = self._make_local_ghost_cell_setter(bc_axis.high)

    @register_jitable
    def ghost_cell_setter(data_full: NumericArray, args=None) -> None:
        """Apply the setter of the upper side, then that of the lower side."""
        set_upper_side(data_full, args=args)
        set_lower_side(data_full, args=args)

    return ghost_cell_setter  # type: ignore
[docs]
def make_ghost_cell_setter(self, bcs: BoundariesBase) -> GhostCellSetter:
    """Return function that sets the ghost cells on a full array.

    For a :class:`BoundariesList`, the setters of the individual axes are composed
    into a single function that calls them in the order of the axes. For a
    :class:`BoundariesSetter`, its `_setter` function is compiled directly.

    Args:
        bcs (:class:`~pde.grids.boundaries.axes.BoundariesBase`):
            Defines the boundary conditions for a particular grid, for which the
            setter should be defined.

    Returns:
        Callable with signature :code:`(data_full: NumericArray, args=None)`, which
        sets the ghost cells of the full data, potentially using additional
        information in `args` (e.g., the time `t` during solving a PDE)

    Raises:
        NotImplementedError: If `bcs` is of neither supported type.
    """
    if isinstance(bcs, BoundariesList):
        axis_setters = tuple(
            self._make_axis_ghost_cell_setter(bc_axis) for bc_axis in bcs
        )

        # TODO: use numba.literal_unroll to loop over `axis_setters` in a single
        # compiled function instead of nesting closures

        def compose(
            previous: GhostCellSetter | None, current: GhostCellSetter
        ) -> GhostCellSetter:
            """Return a setter that runs `previous` (if any) and then `current`."""
            if previous is None:

                @register_jitable
                def wrap(data_full: NumericArray, args=None) -> None:
                    current(data_full, args=args)

            else:

                @register_jitable
                def wrap(data_full: NumericArray, args=None) -> None:
                    previous(data_full, args=args)
                    current(data_full, args=args)

            return wrap  # type: ignore

        # fold all axis setters into one function, preserving the axis order
        combined = compose(None, axis_setters[0])
        for axis_setter in axis_setters[1:]:
            combined = compose(combined, axis_setter)
        return combined

    if isinstance(bcs, BoundariesSetter):
        return self.compile_function(bcs._setter)

    msg = f"Cannot handle following boundary conditions: {bcs}"
    raise NotImplementedError(msg)
[docs]
@cached_method()
def make_operator(
    self,
    grid: GridBase,
    operator: str | OperatorInfo,
    *,
    bcs: BoundariesBase,
    dtype: DTypeLike | None = None,
    **kwargs,
) -> OperatorType:
    """Return a compiled function applying an operator with boundary conditions.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the operator is needed
        operator (str or :class:`~pde.grids.base.OperatorInfo`):
            Identifier for the operator. Some examples are 'laplace', 'gradient', or
            'divergence'. The registered operators for this grid can be obtained
            from the :attr:`~pde.grids.base.GridBase.operators` attribute.
        bcs (:class:`~pde.grids.boundaries.axes.BoundariesBase`):
            The boundary conditions used before the operator is applied
        dtype (numpy dtype):
            The data type of the field. This argument is not referenced in the
            body of this method.
        **kwargs:
            Specifies extra arguments influencing how the operator is created.
            They are forwarded to the factory of the operator.

    The returned function takes the discretized data on the grid as an input and
    returns the data to which the operator `operator` has been applied. The function
    only takes the valid grid points and allocates memory for the ghost points
    internally to apply the boundary conditions specified as `bcs`. Note that the
    function supports an optional argument `out`, which if given should provide
    space for the valid output array without the ghost cells. The result of the
    operator is then written into this output array.

    The function also accepts an optional parameter `args`, which is forwarded to
    the function setting the ghost cells. This allows setting boundary conditions
    based on external parameters, like the current time `t`. Note that since the
    returned operator will always be compiled by Numba, the arguments need to be
    compatible with Numba.

    Returns:
        callable: the function that applies the operator. This function has the
        signature (arr: NumericArray, out: NumericArray = None, args=None).
    """
    # determine the operator for the chosen backend
    operator_info = self.get_operator_info(grid, operator)
    operator_raw = operator_info.factory(grid, backend=self, **kwargs)

    # calculate shapes of the full data
    # (leading tensor dimensions followed by the grid dimensions)
    shape_in_valid = (grid.dim,) * operator_info.rank_in + grid.shape
    shape_in_full = (grid.dim,) * operator_info.rank_in + grid._shape_full
    shape_out = (grid.dim,) * operator_info.rank_out + grid.shape

    # define numpy version of the operator
    def apply_op(
        arr: NumericArray, out: NumericArray | None = None, args=None
    ) -> NumericArray:
        """Set boundary conditions and apply operator."""
        # check input array
        if arr.shape != shape_in_valid:
            msg = f"Incompatible shapes {arr.shape} != {shape_in_valid}"
            raise ValueError(msg)

        # ensure `out` array is allocated and has the right shape
        if out is None:
            out = np.empty(shape_out, dtype=arr.dtype)
        elif out.shape != shape_out:
            msg = f"Incompatible shapes {out.shape} != {shape_out}"
            raise ValueError(msg)

        # prepare input with boundary conditions
        arr_full = np.empty(shape_in_full, dtype=arr.dtype)
        arr_full[(..., *grid._idx_valid)] = arr  # type: ignore
        bcs.set_ghost_cells(arr_full, args=args)

        # apply operator
        operator_raw(arr_full, out)  # type: ignore

        # return valid part of the output
        return out

    # overload `apply_op` with numba-compiled version
    # `set_valid_and_bcs` is called below as `(arr_full, arr, args=args)`;
    # presumably it copies the valid data and sets the ghost cells - TODO confirm
    set_valid_and_bcs = self.make_full_data_setter(bcs=bcs)

    # make sure the raw operator can be called from compiled code
    if not is_jitted(operator_raw):
        operator_raw = self.compile_function(operator_raw)

    @nb_overload(apply_op, inline="always")
    def apply_op_ol(
        arr: NumericArray, out: NumericArray | None = None, args=None
    ) -> NumericArray:
        """Make numba implementation of the operator."""
        # the implementation is selected based on the numba type of `out`
        if isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
            # need to allocate memory for `out`

            def apply_op_impl(
                arr: NumericArray, out: NumericArray | None = None, args=None
            ) -> NumericArray:
                """Allocates `out` and applies operator to the data."""
                if arr.shape != shape_in_valid:
                    raise ValueError("Incompatible shapes of input array")  # noqa: EM101, TRY003

                out = np.empty(shape_out, dtype=arr.dtype)
                # prepare input with boundary conditions
                arr_full = np.empty(shape_in_full, dtype=arr.dtype)
                set_valid_and_bcs(arr_full, arr, args=args)

                # apply operator
                operator_raw(arr_full, out)  # type: ignore

                # return valid part of the output
                return out

        else:
            # reuse provided `out` array

            def apply_op_impl(
                arr: NumericArray, out: NumericArray | None = None, args=None
            ) -> NumericArray:
                """Applies operator to the data without allocating out."""
                if TYPE_CHECKING:
                    assert isinstance(out, np.ndarray)  # help type checker
                if arr.shape != shape_in_valid:
                    raise ValueError("Incompatible shapes of input array")  # noqa: EM101, TRY003
                if out.shape != shape_out:
                    raise ValueError("Incompatible shapes of output array")  # noqa: EM101, TRY003

                # prepare input with boundary conditions
                arr_full = np.empty(shape_in_full, dtype=arr.dtype)
                set_valid_and_bcs(arr_full, arr, args=args)

                # apply operator
                operator_raw(arr_full, out)  # type: ignore

                # return valid part of the output
                return out

        return apply_op_impl  # type: ignore

    # compiled entry point that dispatches to the overload defined above
    @self.compile_function
    def apply_op_compiled(
        arr: NumericArray, out: NumericArray | None = None, args=None
    ) -> NumericArray:
        """Set boundary conditions and apply operator."""
        return apply_op(arr, out, args)

    # return the compiled versions of the operator
    return apply_op_compiled  # type: ignore
def _make_local_integrator(
    self, grid: GridBase
) -> Callable[[NumericArray], NumberOrArray]:
    """Return function that integrates discretized data over a grid.

    The returned function is a plain numpy implementation that is overloaded with
    numba implementations for scalar and tensorial data. The function itself is
    not compiled here; it is intended to be called from compiled code. No
    communication between processes takes place in this function.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the integrator is defined

    Returns:
        A function that takes a numpy array and returns the integral with the
        correct weights given by the cell volumes.
    """
    from . import grids

    num_axes = grid.num_axes
    # cell volume varies with position
    # (the getter is called with a single flat index in the loops below)
    get_cell_volume = grids.make_cell_volume_getter(grid=grid, flat_index=True)

    def integrate_local(arr: NumericArray) -> NumberOrArray:
        """Integrates data over a grid using numpy."""
        # Dummy function so we can overwrite it using numba. This function will only
        # be called when the numba backend is used with DISABLE_JIT=True
        amounts = arr * grid.cell_volumes
        # sum over the trailing grid axes, keeping any leading tensor axes
        return amounts.sum(axis=tuple(range(-num_axes, 0, 1)))  # type: ignore

    # We need to overload the integrate function since we want to be able to
    # integrate scalar and tensorial fields, which lead to different signatures.
    @nb_overload(integrate_local)
    def ol_integrate_local(
        arr: NumericArray,
    ) -> Callable[[NumericArray], NumberOrArray]:
        """Integrates data over a grid using numba."""
        # the implementation is chosen from the number of dimensions of `arr`
        if arr.ndim == num_axes:
            # `arr` is a scalar field
            grid_shape = grid.shape

            def impl(arr: NumericArray) -> Number:
                """Integrate a scalar field."""
                assert arr.shape == grid_shape
                total = 0
                # accumulate the volume-weighted values of all cells
                for i in range(arr.size):
                    total += get_cell_volume(i) * arr.flat[i]
                return total

        else:
            # `arr` is a tensorial field with rank >= 1
            tensor_shape = (grid.dim,) * (arr.ndim - num_axes)
            data_shape = tensor_shape + grid.shape

            def impl(arr: NumericArray) -> NumericArray:  # type: ignore
                """Integrate a tensorial field."""
                assert arr.shape == data_shape
                # NOTE(review): `np.zeros` is called without `dtype`, so the
                # accumulator has numpy's default float type; confirm this is
                # intended for complex-valued data
                total = np.zeros(tensor_shape)
                # integrate each tensor component separately
                for idx in np.ndindex(*tensor_shape):
                    arr_comp = arr[idx]
                    for i in range(arr_comp.size):
                        total[idx] += get_cell_volume(i) * arr_comp.flat[i]
                return total

        return impl

    return integrate_local
[docs]
def make_integrator(  # type: ignore
    self, grid: GridBase
) -> Callable[[NumericArray], NumberOrArray]:
    """Return compiled function that integrates discretized data over a grid.

    This implementation wraps the function returned by
    :meth:`_make_local_integrator` in a compiled function and forwards the data to
    it unchanged.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the integrator is defined

    Returns:
        A function that takes a numpy array and returns the integral with the
        correct weights given by the cell volumes.
    """
    local_integrator = self._make_local_integrator(grid)

    def integrate_global(arr: NumericArray) -> NumberOrArray:
        """Integrate data.

        Args:
            arr (:class:`~numpy.ndarray`): discretized data on grid
        """
        return local_integrator(arr)

    # compile explicitly instead of using the decorator syntax
    return self.compile_function(integrate_global)
[docs]
def make_inner_prod_operator(
    self, field: DataFieldBase, *, conjugate: bool = True
) -> BinaryOperatorImplType:
    """Return operator calculating the dot product between two fields.

    This supports both products between two vectors as well as products
    between a vector and a tensor.

    Args:
        field (:class:`~pde.fields.datafield_base.DataFieldBase`):
            Field for which the inner product is defined
        conjugate (bool):
            Whether to use the complex conjugate for the second operand

    Returns:
        function that takes two instance of :class:`~numpy.ndarray`, which contain
        the discretized data of the two operands. An optional third argument can
        specify the output array to which the result is written.

    Raises:
        NotImplementedError: Raised while compiling a call if the combination of
            ranks of the operands is not supported.
    """
    from .utils import get_common_numba_dtype

    # numpy implementation of the parent class, which is overloaded below
    dot = super().make_inner_prod_operator(field, conjugate=conjugate)

    dim = field.grid.dim
    num_axes = field.grid.num_axes

    @register_jitable
    def maybe_conj(arr: NumericArray) -> NumericArray:
        """Helper function implementing optional conjugation."""
        return arr.conj() if conjugate else arr

    def get_rank(arr: nb.types.Type | nb.types.Optional) -> int:
        """Determine rank of field with type `arr`"""
        arr_typ = arr.type if isinstance(arr, nb.types.Optional) else arr
        if not isinstance(arr_typ, (np.ndarray, nb.types.Array)):
            msg = f"Dot argument must be array, not {arr_typ.__class__}"
            raise nb.errors.TypingError(msg)
        rank = arr_typ.ndim - num_axes
        if rank < 1:
            msg = (
                f"Rank={rank} too small for dot product. Use a normal product "
                "instead."
            )
            raise nb.NumbaTypeError(msg)
        return rank

    @nb_overload(dot, inline="always")
    def dot_ol(
        a: NumericArray, b: NumericArray, out: NumericArray | None = None
    ) -> NumericArray:
        """Numba implementation to calculate dot product between two fields."""
        # get (and check) rank of the input arrays
        rank_a = get_rank(a)
        rank_b = get_rank(b)

        if rank_a == 1 and rank_b == 1:  # result is scalar field

            @register_jitable
            def calc(a: NumericArray, b: NumericArray, out: NumericArray) -> None:
                out[:] = a[0] * maybe_conj(b[0])
                for j in range(1, dim):
                    out[:] += a[j] * maybe_conj(b[j])

        elif rank_a == 2 and rank_b == 1:  # result is vector field

            @register_jitable
            def calc(a: NumericArray, b: NumericArray, out: NumericArray) -> None:
                for i in range(dim):
                    out[i] = a[i, 0] * maybe_conj(b[0])
                    for j in range(1, dim):
                        out[i] += a[i, j] * maybe_conj(b[j])

        elif rank_a == 1 and rank_b == 2:  # result is vector field

            @register_jitable
            def calc(a: NumericArray, b: NumericArray, out: NumericArray) -> None:
                for i in range(dim):
                    out[i] = a[0] * maybe_conj(b[0, i])
                    for j in range(1, dim):
                        out[i] += a[j] * maybe_conj(b[j, i])

        elif rank_a == 2 and rank_b == 2:  # result is tensor-2 field

            @register_jitable
            def calc(a: NumericArray, b: NumericArray, out: NumericArray) -> None:
                for i in range(dim):
                    for j in range(dim):
                        out[i, j] = a[i, 0] * maybe_conj(b[0, j])
                        for k in range(1, dim):
                            out[i, j] += a[i, k] * maybe_conj(b[k, j])

        else:
            msg = "Inner product for these ranks"
            raise NotImplementedError(msg)

        # The expected shapes are needed by both implementations below, so they
        # must be determined before branching on `out`. Otherwise, the variant
        # reusing `out` would close over variables that were never assigned.
        rank_out = rank_a + rank_b - 2
        a_shape = (dim,) * rank_a + field.grid.shape
        b_shape = (dim,) * rank_b + field.grid.shape
        out_shape = (dim,) * rank_out + field.grid.shape

        if isinstance(out, (nb.types.NoneType, nb.types.Omitted)):
            # function is called without `out` -> allocate memory
            dtype = get_common_numba_dtype(a, b)

            def dot_impl(
                a: NumericArray,
                b: NumericArray,
                out: NumericArray | None = None,
            ) -> NumericArray:
                """Helper function allocating output array."""
                assert a.shape == a_shape
                assert b.shape == b_shape
                out = np.empty(out_shape, dtype=dtype)
                calc(a, b, out)
                return out

        else:
            # function is called with `out` argument -> reuse `out` array

            def dot_impl(
                a: NumericArray,
                b: NumericArray,
                out: NumericArray | None = None,
            ) -> NumericArray:
                """Helper function without allocating output array."""
                assert a.shape == a_shape
                assert b.shape == b_shape
                assert out.shape == out_shape  # type: ignore
                calc(a, b, out)
                return out  # type: ignore

        return dot_impl  # type: ignore

    # compiled entry point that dispatches to the overload defined above
    @self.compile_function
    def dot_compiled(
        a: NumericArray, b: NumericArray, out: NumericArray | None = None
    ) -> NumericArray:
        """Numba implementation to calculate dot product between two fields."""
        return dot(a, b, out)  # type: ignore

    return dot_compiled
[docs]
def make_outer_prod_operator(self, field: DataFieldBase) -> BinaryOperatorImplType:
    """Return operator calculating the outer product between two fields.

    Only products between two vector fields are supported.

    Args:
        field (:class:`~pde.fields.datafield_base.DataFieldBase`):
            Field for which the outer product is defined

    Returns:
        function that takes two instance of :class:`~numpy.ndarray`, which contain
        the discretized data of the two operands. An optional third argument can
        specify the output array to which the result is written.

    Raises:
        TypeError: If `field` is not a :class:`VectorField`.
    """
    from .utils import get_common_numba_dtype

    if not isinstance(field, VectorField):
        msg = "Can only define outer product between vector fields"
        raise TypeError(msg)

    dim = field.grid.dim
    num_axes = field.grid.num_axes

    def outer(
        a: NumericArray, b: NumericArray, out: NumericArray | None = None
    ) -> NumericArray:
        """Calculate the outer product using numpy."""
        return np.einsum("i...,j...->ij...", a, b, out=out)

    # everything below overloads `outer` with a numba-compiled version

    def check_rank(arr: nb.types.Type | nb.types.Optional) -> None:
        """Ensure that the numba type `arr` describes the data of a vector field."""
        array_type = arr.type if isinstance(arr, nb.types.Optional) else arr
        if isinstance(array_type, (np.ndarray, nb.types.Array)):
            assert array_type.ndim == 1 + num_axes
            return
        msg = f"Arguments must be array, not {array_type.__class__}"
        raise nb.errors.TypingError(msg)

    @register_jitable
    def calc(a: NumericArray, b: NumericArray, out: NumericArray) -> NumericArray:
        """Write the outer product of the fields `a` and `b` into `out`"""
        for row in range(dim):
            for col in range(dim):
                out[row, col, :] = a[row] * b[col]
        return out

    @nb_overload(outer, inline="always")
    def outer_ol(
        a: NumericArray, b: NumericArray, out: NumericArray | None = None
    ) -> NumericArray:
        """Numba implementation to calculate outer product between two fields."""
        # both operands need to be vector fields
        check_rank(a)
        check_rank(b)

        vector_shape = (dim, *field.grid.shape)
        tensor_shape = (dim, *vector_shape)
        out_missing = isinstance(out, (nb.types.NoneType, nb.types.Omitted))

        if out_missing:
            # no `out` supplied -> the implementation allocates the result
            dtype = get_common_numba_dtype(a, b)

            def outer_impl(
                a: NumericArray,
                b: NumericArray,
                out: NumericArray | None = None,
            ) -> NumericArray:
                """Implementation that allocates the output array."""
                assert a.shape == b.shape == vector_shape
                out = np.empty(tensor_shape, dtype=dtype)
                calc(a, b, out)
                return out

        else:
            # `out` supplied -> the implementation writes into it

            def outer_impl(
                a: NumericArray,
                b: NumericArray,
                out: NumericArray | None = None,
            ) -> NumericArray:
                """Implementation that reuses the supplied output array."""
                assert a.shape == b.shape == vector_shape
                assert out.shape == tensor_shape  # type: ignore
                calc(a, b, out)
                return out  # type: ignore

        return outer_impl  # type: ignore

    def outer_compiled(
        a: NumericArray, b: NumericArray, out: NumericArray | None = None
    ) -> NumericArray:
        """Numba implementation to calculate outer product between two fields."""
        return outer(a, b, out)

    return self.compile_function(outer_compiled)
[docs]
def make_interpolator(
    self,
    field: DataFieldBase,
    *,
    fill: Number | None = None,
    with_ghost_cells: bool = False,
) -> Callable[[FloatingArray, NumericArray], NumberOrArray]:
    r"""Return a compiled function that interpolates the values of a field.

    Args:
        field (:class:`~pde.fields.datafield_base.DataFieldBase`):
            Field for which the interpolator is defined
        fill (Number, optional):
            Value used for points that are out of bounds. It is converted to the
            data type of the field before it is handed to the single-point
            interpolator. If `None`, a `ValueError` is raised when out-of-bounds
            points are requested.
        with_ghost_cells (bool):
            Flag indicating that the interpolator should work on the full data array
            that includes values for the ghost points. If this is the case, the
            boundaries are not checked and the coordinates are used as is.

    Returns:
        A function which returns interpolated values when called with arbitrary
        positions within the space of the grid.
    """
    from . import grids
    from .utils import make_array_constructor

    grid = field.grid
    num_axes = grid.num_axes
    data_shape = field.data_shape

    # make sure `fill` has the same dtype (and shape) as the data at one point
    if fill is not None:
        if field.rank == 0:
            fill = field.data.dtype.type(fill)  # type: ignore
        else:
            fill_array = np.broadcast_to(fill, field.data_shape)
            fill = fill_array.astype(field.data.dtype)  # type: ignore

    # function interpolating the data at a single point
    interpolate_single = grids.make_single_interpolator(
        grid=grid, fill=fill, with_ghost_cells=with_ghost_cells
    )

    # function giving compiled code access to the current data of the field
    default_data = field._data_full if with_ghost_cells else field.data
    get_data_array = make_array_constructor(default_data)

    dim_error_msg = f"Dimension of point does not match axes count {num_axes}"

    @self.compile_function
    def interpolator(
        point: FloatingArray, data: NumericArray | None = None
    ) -> NumericArray:
        """Return the interpolated value at the position `point`

        Args:
            point (:class:`~numpy.ndarray`):
                The list of points. This point coordinates should be given along the
                last axis, i.e., the shape should be `(..., num_axes)`.
            data (:class:`~numpy.ndarray`, optional):
                The discretized field values. If omitted, the data of the current
                field is used, which should be the default. However, this option can
                be useful to interpolate other fields defined on the same grid
                without recreating the interpolator. If a data array is supplied, it
                needs to be the full data if `with_ghost_cells == True`, and
                otherwise only the valid data.

        Returns:
            :class:`~numpy.ndarray`: The interpolated values at the points
        """
        # validate the coordinates
        point = np.atleast_1d(point)
        if point.shape[-1] != num_axes:
            raise DimensionError(dim_error_msg)
        batch_shape = point.shape[:-1]

        if data is None:
            # fall back to the data of the field the interpolator was made for
            data = get_data_array()

        # interpolate the data at each of the points
        result = np.empty(data_shape + batch_shape, dtype=data.dtype)
        for index in np.ndindex(*batch_shape):
            result[(..., *index)] = interpolate_single(data, point[index])

        return result

    # keep the data referenced so it is not garbage collected too early
    interpolator._data = field.data  # type: ignore
    return interpolator
[docs]
def make_inserter(
    self, grid: GridBase, *, with_ghost_cells: bool = False
) -> Callable[[InexactArray, FloatingArray, NumberOrArray], None]:
    """Return a compiled function that adds an amount at an interpolated position.

    The amount is spread over the support points that surround the position along
    each axis. Each contribution is scaled by the product of the per-axis weights
    and divided by the volume of the cell that receives it.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the inserter is defined
        with_ghost_cells (bool):
            Flag indicating that the inserter should work on the full data array
            instead of only the valid data. If this is the case, the boundaries
            are not checked and the coordinates are used as is. The flag is
            forwarded unchanged to the per-axis interpolation helpers.

    Returns:
        callable: A function with signature (data, position, amount), where `data`
        is the numpy array containing the field data, `position` denotes the
        position in grid coordinates, and `amount` is the value that is to be
        added to the field.

    Raises:
        NotImplementedError: If the grid has more than three (or fewer than one)
            axes. The returned function itself raises
            :class:`~pde.grids.DomainError` for points outside the grid domain.
    """
    from . import grids

    cell_volume = grids.make_cell_volume_getter(grid=grid, flat_index=False)
    # factory for the per-axis helper returning (low index, high index, low
    # weight, high weight) for a coordinate along the respective axis
    axis_support = functools.partial(
        grids.make_interpolation_axis_data,
        grid=grid,
        with_ghost_cells=with_ghost_cells,
    )

    if grid.num_axes == 1:
        # specialization for one axis: two support points
        support_x = axis_support(axis=0)

        def insert(
            data: InexactArray, point: FloatingArray, amount: NumberOrArray
        ) -> None:
            """Add an amount to a 1d field at an interpolated position.

            Args:
                data (:class:`~numpy.ndarray`):
                    The values at the grid points
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point in the grid coordinate system
                amount (Number or :class:`~numpy.ndarray`):
                    Integrated quantity (field value times discretization
                    volume) that is added; it is divided by the cell volume
                    before it is written to `data`
            """
            idx_lo, idx_hi, wgt_lo, wgt_hi = support_x(float(point[0]))
            # the value -42 in the lower index is treated as "outside the grid"
            if idx_lo == -42:  # out of bounds
                msg = "Point lies outside the grid domain"
                raise DomainError(msg)

            data[..., idx_lo] += wgt_lo * amount / cell_volume(idx_lo)  # type: ignore
            data[..., idx_hi] += wgt_hi * amount / cell_volume(idx_hi)  # type: ignore

        return self.compile_function(insert)

    if grid.num_axes == 2:
        # specialization for two axes: four support points
        support_x = axis_support(axis=0)
        support_y = axis_support(axis=1)

        def insert(
            data: InexactArray, point: FloatingArray, amount: NumberOrArray
        ) -> None:
            """Add an amount to a 2d field at an interpolated position.

            Args:
                data (:class:`~numpy.ndarray`):
                    The values at the grid points
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point in the grid coordinate system
                amount (Number or :class:`~numpy.ndarray`):
                    Integrated quantity (field value times discretization
                    volume) that is added; it is divided by the cell volume
                    before it is written to `data`
            """
            # surrounding indices and weights along each axis
            ix_lo, ix_hi, wx_lo, wx_hi = support_x(float(point[0]))
            iy_lo, iy_hi, wy_lo, wy_hi = support_y(float(point[1]))
            if ix_lo == -42 or iy_lo == -42:  # out of bounds
                msg = "Point lies outside the grid domain"
                raise DomainError(msg)

            # distribute the amount over the four corners
            data[..., ix_lo, iy_lo] += (  # type: ignore
                wx_lo * wy_lo * amount / cell_volume(ix_lo, iy_lo)
            )
            data[..., ix_lo, iy_hi] += (  # type: ignore
                wx_lo * wy_hi * amount / cell_volume(ix_lo, iy_hi)
            )
            data[..., ix_hi, iy_lo] += (  # type: ignore
                wx_hi * wy_lo * amount / cell_volume(ix_hi, iy_lo)
            )
            data[..., ix_hi, iy_hi] += (  # type: ignore
                wx_hi * wy_hi * amount / cell_volume(ix_hi, iy_hi)
            )

        return self.compile_function(insert)

    if grid.num_axes == 3:
        # specialization for three axes: eight support points
        support_x = axis_support(axis=0)
        support_y = axis_support(axis=1)
        support_z = axis_support(axis=2)

        def insert(
            data: InexactArray, point: FloatingArray, amount: NumberOrArray
        ) -> None:
            """Add an amount to a 3d field at an interpolated position.

            Args:
                data (:class:`~numpy.ndarray`):
                    The values at the grid points
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point in the grid coordinate system
                amount (Number or :class:`~numpy.ndarray`):
                    Integrated quantity (field value times discretization
                    volume) that is added; it is divided by the cell volume
                    before it is written to `data`
            """
            # surrounding indices and weights along each axis
            ix_lo, ix_hi, wx_lo, wx_hi = support_x(float(point[0]))
            iy_lo, iy_hi, wy_lo, wy_hi = support_y(float(point[1]))
            iz_lo, iz_hi, wz_lo, wz_hi = support_z(float(point[2]))
            if ix_lo == -42 or iy_lo == -42 or iz_lo == -42:  # out of bounds
                msg = "Point lies outside the grid domain"
                raise DomainError(msg)

            # distribute the amount over the eight corners
            data[..., ix_lo, iy_lo, iz_lo] += (  # type: ignore
                wx_lo * wy_lo * wz_lo * amount / cell_volume(ix_lo, iy_lo, iz_lo)
            )
            data[..., ix_lo, iy_lo, iz_hi] += (  # type: ignore
                wx_lo * wy_lo * wz_hi * amount / cell_volume(ix_lo, iy_lo, iz_hi)
            )
            data[..., ix_lo, iy_hi, iz_lo] += (  # type: ignore
                wx_lo * wy_hi * wz_lo * amount / cell_volume(ix_lo, iy_hi, iz_lo)
            )
            data[..., ix_lo, iy_hi, iz_hi] += (  # type: ignore
                wx_lo * wy_hi * wz_hi * amount / cell_volume(ix_lo, iy_hi, iz_hi)
            )
            data[..., ix_hi, iy_lo, iz_lo] += (  # type: ignore
                wx_hi * wy_lo * wz_lo * amount / cell_volume(ix_hi, iy_lo, iz_lo)
            )
            data[..., ix_hi, iy_lo, iz_hi] += (  # type: ignore
                wx_hi * wy_lo * wz_hi * amount / cell_volume(ix_hi, iy_lo, iz_hi)
            )
            data[..., ix_hi, iy_hi, iz_lo] += (  # type: ignore
                wx_hi * wy_hi * wz_lo * amount / cell_volume(ix_hi, iy_hi, iz_lo)
            )
            data[..., ix_hi, iy_hi, iz_hi] += (  # type: ignore
                wx_hi * wy_hi * wz_hi * amount / cell_volume(ix_hi, iy_hi, iz_hi)
            )

        return self.compile_function(insert)

    msg = f"Compiled interpolation not implemented for dimension {grid.num_axes}"
    raise NotImplementedError(msg)
[docs]
def make_pde_rhs(
    self, eq: PDEBase, state: TField
) -> Callable[[NumericArray, float], NumericArray]:
    """Return a compiled function evaluating the right hand side of the PDE.

    Args:
        eq (:class:`~pde.pdes.base.PDEBase`):
            The object describing the differential equation
        state (:class:`~pde.fields.FieldBase`):
            An example for the state from which information can be extracted

    Returns:
        Function returning deterministic part of the right hand side of the PDE

    Raises:
        NotImplementedError: If looking up `eq.make_evolution_rate` fails
    """
    # Legacy hook, deprecated since 2026-03-02: when the equation still defines
    # `make_pde_rhs_numba`, it takes precedence but triggers a warning. By default,
    # the attribute is expected to be missing.
    not_defined = object()
    legacy_factory = getattr(eq, "make_pde_rhs_numba", not_defined)
    if legacy_factory is not not_defined:
        warnings.warn(
            "`eq.make_pde_rhs_numba` method is deprecated. Implement "
            "`eq.make_evolution_rate` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.compile_function(legacy_factory(state))  # type: ignore

    # NOTE(review): only the attribute lookup is guarded here; exceptions raised
    # while calling the factory below propagate unchanged.
    try:
        rate_factory = eq.make_evolution_rate
    except (AttributeError, NotImplementedError) as err:
        msg = (
            f"Evolution rate is not implemented for the `{self.name}` backend. To "
            "add the implementation, provide the method `make_evolution_rate` "
            "that returns a function calculating the evolution rate."
        )
        raise NotImplementedError(msg) from err

    return self.compile_function(rate_factory(state, backend=self))
[docs]
def make_expression_function(
    self,
    expression: ExpressionBase,
    *,
    single_arg: bool = False,
    user_funcs: dict[str, Callable] | None = None,
) -> Callable[..., NumberOrArray]:
    """Return a compiled function evaluating an expression.

    Args:
        expression (:class:`~pde.tools.expression.ExpressionBase`):
            The expression that is converted to a function
        single_arg (bool):
            Determines whether the returned function accepts all variables in a
            single argument as an array or whether all variables need to be
            supplied separately.
        user_funcs (dict):
            Additional functions that can be used in the expression. They
            overwrite functions of the same name stored on the expression, and
            both are overwritten by the special functions of the package.

    Returns:
        function: the compiled function
    """
    import sympy
    from sympy.printing.pycode import PythonCodePrinter

    from ...tools.expressions import SPECIAL_FUNCTIONS

    def make_compilable(func):
        """Compile a helper function so it can be used inside the expression."""
        if not isinstance(func, np.ufunc):
            return self.compile_function(func)
        # work-around for numpy ufuncs: compile a thin wrapper instead
        return self.compile_function(lambda *args: func(*args))

    # gather the helper functions from all sources (later ones take precedence)
    helpers = expression.user_funcs.copy()
    if user_funcs is not None:
        helpers.update(user_funcs)
    helpers.update(SPECIAL_FUNCTIONS)
    helpers = {name: make_compilable(func) for name, func in helpers.items()}

    class ListArrayPrinter(PythonCodePrinter):
        """Special sympy printer returning arrays as lists."""

        def _print_ImmutableDenseNDimArray(self, arr):
            entries = (str(self._print(item)) for item in arr)
            return "[" + ", ".join(entries) + "]"

    printer_settings = {
        "fully_qualified_modules": False,
        "inline": True,
        "allow_unknown_functions": True,
        # helper functions are printed under their own name
        "user_functions": {name: name for name in helpers},
    }
    printer = ListArrayPrinter(printer_settings)

    # arguments of the generated function: variables first, then constants
    if single_arg:
        variables = (expression.vars,)
    else:
        variables = tuple(expression.vars)
    constants = tuple(expression.consts)

    self._logger.info("Parse sympy expression `%s`", expression._sympy_expr)
    func = sympy.lambdify(
        variables + constants,
        expression._sympy_expr,
        modules=[helpers, "numpy"],
        printer=printer,
    )

    if not constants:
        # nothing needs to be bound, so the generated function is used directly
        return self.compile_function(func)

    # Bind the constants through a wrapper instead of substituting them into the
    # sympy expression, since sympy does not work well with numpy arrays.
    const_values = tuple(expression.consts[c] for c in constants)
    func = register_jitable(func)

    def result(*args):
        return func(*args, *const_values)

    return self.compile_function(result)
def _make_expression_array(
    self, expression: ExpressionBase, *, single_arg: bool = True
) -> Callable[[NumericArray, NumericArray | None], NumericArray]:
    """Compile the tensor expression such that a numpy array is returned.

    The method generates the source code of a function that fills an output array
    entry by entry, executes this code to obtain the function, and compiles it.

    Args:
        expression (:class:`~pde.tools.expression.ExpressionBase`):
            The expression that is converted to a function
        single_arg (bool):
            Whether the compiled function expects all arguments as a single array
            or whether they are supplied individually.

    Returns:
        callable: The compiled function. It accepts an optional argument `out`;
        if `out` is `None`, a new array is allocated for the result.

    Raises:
        TypeError: If the expression is not a sympy array
    """
    import builtins

    import sympy
    from sympy.utilities.lambdify import _get_namespace

    if not isinstance(expression._sympy_expr, sympy.Array):
        msg = "Expression must be an array"
        raise TypeError(msg)

    variables = ", ".join(expression.vars)
    shape = expression._sympy_expr.shape

    # indentation of the generated code: function body and nested `if` body
    indent1 = " " * 4
    indent2 = " " * 8

    # one assignment per entry of the array, e.g. `out[0, Ellipsis] = <expr>`
    assignments = [
        f"{indent1}out[{str((*idx, ...))[1:-1]}] = {expr}"
        for idx, expr in np.ndenumerate(expression._sympy_expr)
    ]
    # TODO: We should also support constants similar to ScalarExpressions. They
    # could be written in separate lines and prepended to the actual code. However,
    # we would need to make sure to print numpy arrays correctly.

    if variables:
        # the expression takes variables as input
        if single_arg:
            # the function takes a single input array; with several variables the
            # first axis enumerates them and is thus not part of the output shape
            first_dim = 0 if len(expression.vars) == 1 else 1
            header = [
                "def _generated_function(arr, out=None):",
                f"{indent1}arr = asarray(arr)",
                f"{indent1}{variables} = arr",
                f"{indent1}if out is None:",
                f"{indent2}out = empty({shape} + arr.shape[{first_dim}:])",
            ]
        else:
            # the function takes each variable as an argument
            header = [
                f"def _generated_function({variables}, out=None):",
                f"{indent1}if out is None:",
                f"{indent2}out = empty({shape} + shape({expression.vars[0]}))",
            ]
    else:
        # the expression is constant
        if single_arg:
            header = ["def _generated_function(arr=None, out=None):"]
        else:
            header = ["def _generated_function(out=None):"]
        header += [
            f"{indent1}if out is None:",
            f"{indent2}out = empty({shape})",
        ]

    code = "\n".join([*header, *assignments, f"{indent1}return out"])
    self._logger.debug("Code for `make_expression_array`: %s", code)

    # Work on a copy of the namespace: the dictionary handed out by sympy may be
    # shared with other users (e.g., `lambdify`) and must not be polluted with
    # the user functions of this particular expression.
    namespace = dict(_get_namespace("numpy"))
    namespace["builtins"] = builtins
    namespace.update(expression.user_funcs)

    # NOTE: the generated code contains the printed expression and is executed
    # here, so expressions must only originate from trusted sources.
    local_vars: dict[str, Any] = {}
    exec(code, namespace, local_vars)
    function = local_vars["_generated_function"]

    return self.compile_function(function)  # type: ignore
[docs]
def make_mpi_synchronizer(
    self, operator: int | str = "MAX", mpi_run: bool = False
) -> Callable[[float], float]:
    """Return function that synchronizes values between multiple MPI processes.

    The synchronizer is created by the parent class; this method only registers
    it with numba so that it can be called from compiled code.

    Warning:
        The default implementation does not synchronize anything. This is simply a
        hook, which can be used by backends that support MPI

    Args:
        operator (str or int):
            Flag determining how the value from multiple nodes is combined.
            Possible values include "MAX", "MIN", and "SUM".
        mpi_run (bool):
            Whether MPI is actually used. If `False`, the method returns a no-op.

    Returns:
        Function that can be used to synchronize values across nodes
    """
    synchronizer = super().make_mpi_synchronizer(operator=operator, mpi_run=mpi_run)
    return register_jitable(synchronizer)  # type: ignore
[docs]
def make_gaussian_noise(
    self, field: TField, *, rng: np.random.Generator
) -> Callable[[], NumericArray]:
    """Create a compiled function generating Gaussian white noise.

    Args:
        field (:class:`~pde.fields.base.FieldBase`):
            An example for the state. Only the shape of its data is used, which
            determines the shape of the generated noise.
        rng (:class:`~numpy.random.Generator`):
            Random number generator. Not used in this numba backend, where the
            samples are drawn by calling :func:`numpy.random.randn` instead.

    Returns:
        callable: Function without arguments that returns an array of normally
        distributed random numbers with the shape of the field data
    """
    # the shape is fixed when the noise function is created
    shape: tuple[int, ...] = field.data.shape

    def gaussian_noise() -> NumericArray:
        """Generate Gaussian white noise."""
        return np.random.randn(*shape)

    return self.compile_function(gaussian_noise)
[docs]
def make_stepper(self, solver: SolverBase, state: TField) -> StepperType:
    """Create a field-based stepping function for a given solver.

    Args:
        solver (:class:`~pde.solvers.base.SolverBase`):
            The solver instance, which determines how the stepper is constructed.
            Its backend must be this backend.
        state (:class:`~pde.fields.base.FieldBase`):
            An example for the state from which the grid and other information can
            be extracted

    Returns:
        Function that can be called to advance the `state` from time `t_start` to
        time `t_end`. The function is called as `(state, t_start: float,
        t_end: float)`, where `state` is a field whose `data` attribute is
        forwarded to the backend-level stepping function.
    """
    from ._solvers import make_inner_stepper

    assert solver.backend == self
    # stepping function that operates directly on the data array
    advance_data = make_inner_stepper(solver, state)

    # The PDE is not accessed here since the solver infrastructure might be reused
    # for more general cases where a PDE is not defined.

    def stepper(state: TField, t_start: float, t_end: float) -> float:
        """Advance `state` by handing its data to the backend-level stepper."""
        return advance_data(state.data, t_start, t_end)

    return stepper  # type: ignore