Source code for pde.backends.numba.operators.common

"""Common functions that are used by many operators.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from .... import get_backend
from ..utils import jit

if TYPE_CHECKING:
    from ....grids.base import GridBase
    from ....tools.typing import NumericArray, OperatorImplType
    from ..backend import NumbaBackend


[docs] def make_derivative( grid: GridBase, axis: int = 0, *, method: Literal["central", "forward", "backward"] = "central", backend: NumbaBackend | None = None, ) -> OperatorImplType: """Make a derivative operator along a single axis using numba compilation. Args: grid (:class:`~pde.grids.base.GridBase`): The grid for which the operator is created axis (int): The axis along which the derivative will be taken method (str): The method for calculating the derivative. Possible values are 'central', 'forward', and 'backward'. backend (:class:`~pde.backends.numba.backend.NumbaBackend`): References to the backend to read configuration details Returns: A function that can be applied to an full array of values including those at ghost cells. The result will be an array of the same shape containing the actual derivatives at the valid (interior) grid points. """ if backend is None: backend = get_backend("numba") # type: ignore if method not in {"central", "forward", "backward"}: msg = f"Unknown derivative type `{method}`" raise ValueError(msg) shape = grid.shape num_axes = len(shape) dx = grid.discretization[axis] if axis == 0: di, dj, dk = 1, 0, 0 elif axis == 1: di, dj, dk = 0, 1, 0 elif axis == 2: di, dj, dk = 0, 0, 1 else: msg = f"Derivative for {axis:d} dimensions" raise NotImplementedError(msg) if num_axes == 1: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 1d array `arr`""" for i in range(1, shape[0] + 1): if method == "central": out[i - 1] = (arr[i + 1] - arr[i - 1]) / (2 * dx) elif method == "forward": out[i - 1] = (arr[i + 1] - arr[i]) / dx elif method == "backward": out[i - 1] = (arr[i] - arr[i - 1]) / dx elif num_axes == 2: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 2d array `arr`""" for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): arr_l = arr[i - di, j - dj] arr_r = arr[i + di, j + dj] if method == "central": out[i - 1, j - 1] = (arr_r - arr_l) / (2 * dx) elif method == "forward": out[i - 1, j - 1] = (arr_r - arr[i, j]) / dx elif method == "backward": out[i - 1, j - 1] = (arr[i, j] - arr_l) / dx elif num_axes == 3: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 3d array `arr`""" for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): for k in range(1, shape[2] + 1): arr_l = arr[i - di, j - dj, k - dk] arr_r = arr[i + di, j + dj, k + dk] if method == "central": out[i - 1, j - 1, k - 1] = (arr_r - arr_l) / (2 * dx) elif method == "forward": out[i - 1, j - 1, k - 1] = (arr_r - arr[i, j, k]) / dx elif method == "backward": out[i - 1, j - 1, k - 1] = (arr[i, j, k] - arr_l) / dx else: msg = f"Numba derivative operator not implemented for {num_axes:d} axes" raise NotImplementedError(msg) return diff
[docs] def make_derivative2( grid: GridBase, axis: int = 0, *, backend: NumbaBackend | None = None, ) -> OperatorImplType: """Make a second-order derivative operator along a single axis. Args: grid (:class:`~pde.grids.base.GridBase`): The grid for which the operator is created axis (int): The axis along which the derivative will be taken backend (:class:`~pde.backends.numba.backend.NumbaBackend`): References to the backend to read configuration details Returns: A function that can be applied to an full array of values including those at ghost cells. The result will be an array of the same shape containing the actual derivatives at the valid (interior) grid points. """ if backend is None: backend = get_backend("numba") # type: ignore shape = grid.shape num_axes = len(shape) scale = 1 / grid.discretization[axis] ** 2 if axis == 0: di, dj, dk = 1, 0, 0 elif axis == 1: di, dj, dk = 0, 1, 0 elif axis == 2: di, dj, dk = 0, 0, 1 else: msg = f"Derivative for {axis:d} dimensions" raise NotImplementedError(msg) if num_axes == 1: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 1d array `arr`""" for i in range(1, shape[0] + 1): out[i - 1] = (arr[i + 1] - 2 * arr[i] + arr[i - 1]) * scale elif num_axes == 2: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 2d array `arr`""" for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): arr_l = arr[i - di, j - dj] arr_r = arr[i + di, j + dj] out[i - 1, j - 1] = (arr_r - 2 * arr[i, j] + arr_l) * scale elif num_axes == 3: @jit(backend=backend) def diff(arr: NumericArray, out: NumericArray) -> None: """Calculate derivative of 3d array `arr`""" for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): for k in range(1, shape[2] + 1): arr_l = arr[i - di, j - dj, k - dk] arr_r = arr[i + di, j + dj, k + dk] out[i - 1, j - 1, k - 1] = ( arr_r - 2 * arr[i, j, k] + arr_l ) * scale else: msg = f"Numba derivative operator not implemented for {num_axes:d} axes" raise NotImplementedError(msg) return diff