"""Compiled functions for dealing with grids.
.. autosummary::
:nosignatures:
get_grid_numba_type
make_cell_volume_getter
make_interpolation_axis_data
make_single_interpolator
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from numba.extending import register_jitable
from ...grids import DomainError, GridBase
from .utils import jit
if TYPE_CHECKING:
from collections.abc import Callable
from ...tools.typing import (
CellVolume,
FloatingArray,
Number,
NumberOrArray,
NumericArray,
)
[docs]
def get_grid_numba_type(grid: GridBase, rank: int = 0) -> str:
    """Return the numba type string of a data array defined on a grid.

    Args:
        grid (GridBase):
            The grid for which we determine the type
        rank (int):
            The rank of the data stored in the grid; every unit of rank adds one
            array dimension on top of the grid axes

    Returns:
        str: A numba type string describing a float64 (``f8``) array with
        ``grid.num_axes + rank`` dimensions, e.g. ``"f8[:, :]"`` for scalar data
        on a grid with two axes
    """
    num_dims = grid.num_axes + rank
    slices = ", ".join(":" for _ in range(num_dims))
    return f"f8[{slices}]"
[docs]
def make_cell_volume_getter(grid: GridBase, *, flat_index: bool = False) -> CellVolume:
    """Create a compiled function that reports the volume of a single grid cell.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid whose cell volumes are queried
        flat_index (bool):
            If True, the returned function expects a single integer indexing the
            flattened array of cell volumes; otherwise it expects one index per
            axis. The flag has no effect when all cells share the same volume.

    Returns:
        function: compiled function mapping cell indices to the cell volume
    """
    volume_data = grid.cell_volume_data
    is_uniform = volume_data is not None and all(
        np.isscalar(item) for item in volume_data
    )

    if is_uniform:
        # every cell has the same volume, so the indices are irrelevant
        uniform_volume = np.prod(volume_data)  # type: ignore

        @jit
        def get_cell_volume(*args) -> float:
            return uniform_volume  # type: ignore

        return get_cell_volume

    # volumes differ between cells, so they have to be looked up
    volumes = grid.cell_volumes

    if flat_index:

        @jit
        def get_cell_volume(idx: int) -> float:
            return volumes.flat[idx]  # type: ignore

    else:

        @jit
        def get_cell_volume(*args) -> float:
            return volumes[args]  # type: ignore

    return get_cell_volume
[docs]
def make_interpolation_axis_data(
    grid: GridBase,
    axis: int,
    *,
    with_ghost_cells: bool = False,
    cell_coords: bool = False,
) -> Callable[[float], tuple[int, int, float, float]]:
    """Factory for obtaining interpolation information along a single axis.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the interpolator is defined
        axis (int):
            The axis along which interpolation is performed
        with_ghost_cells (bool):
            Flag indicating that the interpolator should work on the full data array
            that includes values for the ghost points. In this case, the returned
            indices are shifted by one to address this padded array. On
            non-periodic axes, coordinates up to the domain boundary (half a cell
            beyond the outermost support points) are accepted, using the ghost
            values there. Coordinates further outside are reported as out of bounds.
        cell_coords (bool):
            Flag indicating whether points are given in cell coordinates or actual
            point coordinates. With cell coordinates, the value `i` falls exactly
            on support point `i`.

    Returns:
        callable: A function that is called with a coordinate value for the axis.
        It returns a tuple `(c_li, c_hi, w_l, w_h)` with the indices of the lower
        and upper neighboring support points and their associated weights. On
        non-periodic axes, points outside the accepted range yield
        `(-42, -42, 0.0, 0.0)`. Periodic axes wrap around and never yield this
        sentinel.
    """
    # obtain information on how this axis is discretized
    size = grid.shape[axis]  # number of cells along the axis
    periodic = grid.periodic[axis]
    lo = grid.axes_bounds[axis][0]  # lower bound of the axis
    dx = grid.discretization[axis]  # spacing between neighboring support points

    @register_jitable
    def get_axis_data(coord: float) -> tuple[int, int, float, float]:
        """Determines data for interpolating along one axis.

        Args:
            coord (float):
                Coordinate along the axis (a cell coordinate if `cell_coords`)

        Returns:
            tuple: `(c_li, c_hi, w_l, w_h)`, or `(-42, -42, 0.0, 0.0)` if the
            coordinate lies out of bounds
        """
        # determine the index of the left cell and the fraction toward the right;
        # in point coordinates, support point i sits at lo + (i + 0.5) * dx
        if cell_coords:
            c_l, d_l = divmod(coord, 1.0)
        else:
            c_l, d_l = divmod((coord - lo) / dx - 0.5, 1.0)

        # determine the indices of the two cells whose value affect interpolation
        if periodic:
            # deal with periodic domains, which is easy: indices wrap around, so
            # no coordinate is ever out of bounds
            c_li = int(c_l) % size  # left support point
            c_hi = (c_li + 1) % size  # right support point
        elif with_ghost_cells:
            # deal with edge cases using the values of ghost cells; the accepted
            # range extends half a cell beyond the outermost support points
            if -0.5 <= c_l + d_l <= size - 0.5:  # in bulk part of domain
                c_li = int(c_l)  # left support point
                c_hi = c_li + 1  # right support point
            else:
                return -42, -42, 0.0, 0.0  # indicates out of bounds
        else:
            # deal with edge cases using nearest-neighbor interpolation at boundary:
            # between the outermost support point and the domain boundary, both
            # support points coincide, so that point's value is used unchanged
            if 0 <= c_l + d_l < size - 1:  # in bulk part of domain
                c_li = int(c_l)  # left support point
                c_hi = c_li + 1  # right support point
            elif size - 1 <= c_l + d_l <= size - 0.5:  # close to upper boundary
                c_li = c_hi = int(c_l)  # both support points close to boundary
                # This branch also covers the special case, where size == 1 and data
                # is evaluated at the only support point (c_l == d_l == 0.)
            elif -0.5 <= c_l + d_l <= 0:  # close to lower boundary
                # here c_l == -1 (or the point sits exactly on support point 0),
                # so the nearest support point is index 0
                c_li = c_hi = int(c_l) + 1  # both support points close to boundary
            else:
                return -42, -42, 0.0, 0.0  # indicates out of bounds

        # determine the weights of the two cells
        w_l, w_h = 1 - d_l, d_l
        # set small weights to zero. If this is not done, invalid data at the corner
        # of the grid (where two rows of ghost cells intersect) could be accessed.
        # If this random data is very large, e.g., 1e100, it contributes
        # significantly, even if the weight is low, e.g., 1e-16.
        if w_l < 1e-15:
            w_l = 0
        if w_h < 1e-15:
            w_h = 0

        # shift points to allow accessing data with ghost points; this assumes the
        # full data array starts with one layer of ghost cells
        if with_ghost_cells:
            c_li += 1
            c_hi += 1

        return c_li, c_hi, w_l, w_h

    return get_axis_data  # type: ignore
[docs]
def make_single_interpolator(
    grid: GridBase,
    *,
    fill: Number | None = None,
    with_ghost_cells: bool = False,
    cell_coords: bool = False,
) -> Callable[[NumericArray, FloatingArray], NumberOrArray]:
    """Return a compiled function for linear interpolation on the grid.

    Args:
        grid (:class:`~pde.grid.base.GridBase`):
            Grid for which the interpolator is defined. Only grids with one, two,
            or three axes are supported.
        fill (Number, optional):
            Determines how values out of bounds are handled. If `None`, a
            :class:`DomainError` is raised when out-of-bounds points are
            requested. Otherwise, the given value is returned.
        with_ghost_cells (bool):
            Flag indicating that the interpolator should work on the full data array
            that includes values for the ghost points.
        cell_coords (bool):
            Flag indicating whether points are given in cell coordinates or actual
            point coordinates.

    Returns:
        callable: A function which returns interpolated values when called with
        arbitrary positions within the space of the grid. The signature of this
        function is (data, point), where `data` is the numpy array containing the
        field data and `point` denotes the position of a single point, given in
        grid or cell coordinates depending on `cell_coords`. The grid axes must be
        the last axes of `data`; any leading axes are interpolated simultaneously.

    Raises:
        NotImplementedError: if the grid does not have one, two, or three axes
    """
    # options shared by the per-axis interpolation helpers
    axis_kwargs = {"with_ghost_cells": with_ghost_cells, "cell_coords": cell_coords}

    if grid.num_axes == 1:
        # specialize for 1-dimensional interpolation
        data_x = make_interpolation_axis_data(grid=grid, axis=0, **axis_kwargs)

        @jit
        def interpolate_single(
            data: NumericArray, point: FloatingArray
        ) -> NumberOrArray:
            """Obtain interpolated value of data at a point.

            Args:
                data (:class:`~numpy.ndarray`):
                    Array of values at the support points, whose last axis
                    corresponds to the single grid axis
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point (grid or cell coordinates)

            Returns:
                :class:`~numpy.ndarray`: The interpolated value at the point, or
                `fill` if the point is out of bounds and `fill` is given
            """
            c_li, c_hi, w_l, w_h = data_x(float(point[0]))
            if c_li == -42:  # out of bounds
                if fill is None:  # outside the domain
                    # the error message is a fixed string, so report the point here
                    print("POINT", point)
                    msg = "Point lies outside the grid domain"
                    raise DomainError(msg)
                return fill

            # do the linear interpolation
            return w_l * data[..., c_li] + w_h * data[..., c_hi]

    elif grid.num_axes == 2:
        # specialize for 2-dimensional interpolation
        data_x = make_interpolation_axis_data(grid=grid, axis=0, **axis_kwargs)
        data_y = make_interpolation_axis_data(grid=grid, axis=1, **axis_kwargs)

        @jit
        def interpolate_single(
            data: NumericArray, point: FloatingArray
        ) -> NumberOrArray:
            """Obtain interpolated value of data at a point.

            Args:
                data (:class:`~numpy.ndarray`):
                    Array of values at the support points, whose last two axes
                    correspond to the grid axes
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point (grid or cell coordinates)

            Returns:
                :class:`~numpy.ndarray`: The interpolated value at the point, or
                `fill` if the point is out of bounds and `fill` is given
            """
            # determine surrounding points and their weights
            c_xli, c_xhi, w_xl, w_xh = data_x(float(point[0]))
            c_yli, c_yhi, w_yl, w_yh = data_y(float(point[1]))

            if c_xli == -42 or c_yli == -42:  # out of bounds
                if fill is None:  # outside the domain
                    # the error message is a fixed string, so report the point here
                    print("POINT", point)
                    msg = "Point lies outside the grid domain"
                    raise DomainError(msg)
                return fill

            # do the bilinear interpolation
            return (
                w_xl * w_yl * data[..., c_xli, c_yli]
                + w_xl * w_yh * data[..., c_xli, c_yhi]
                + w_xh * w_yl * data[..., c_xhi, c_yli]
                + w_xh * w_yh * data[..., c_xhi, c_yhi]
            )

    elif grid.num_axes == 3:
        # specialize for 3-dimensional interpolation
        data_x = make_interpolation_axis_data(grid=grid, axis=0, **axis_kwargs)
        data_y = make_interpolation_axis_data(grid=grid, axis=1, **axis_kwargs)
        data_z = make_interpolation_axis_data(grid=grid, axis=2, **axis_kwargs)

        @jit
        def interpolate_single(
            data: NumericArray, point: FloatingArray
        ) -> NumberOrArray:
            """Obtain interpolated value of data at a point.

            Args:
                data (:class:`~numpy.ndarray`):
                    Array of values at the support points, whose last three axes
                    correspond to the grid axes
                point (:class:`~numpy.ndarray`):
                    Coordinates of a single point (grid or cell coordinates)

            Returns:
                :class:`~numpy.ndarray`: The interpolated value at the point, or
                `fill` if the point is out of bounds and `fill` is given
            """
            # determine surrounding points and their weights
            c_xli, c_xhi, w_xl, w_xh = data_x(float(point[0]))
            c_yli, c_yhi, w_yl, w_yh = data_y(float(point[1]))
            c_zli, c_zhi, w_zl, w_zh = data_z(float(point[2]))

            if c_xli == -42 or c_yli == -42 or c_zli == -42:  # out of bounds
                if fill is None:  # outside the domain
                    # the error message is a fixed string, so report the point here
                    print("POINT", point)
                    msg = "Point lies outside the grid domain"
                    raise DomainError(msg)
                return fill

            # do the trilinear interpolation
            return (
                w_xl * w_yl * w_zl * data[..., c_xli, c_yli, c_zli]
                + w_xl * w_yl * w_zh * data[..., c_xli, c_yli, c_zhi]
                + w_xl * w_yh * w_zl * data[..., c_xli, c_yhi, c_zli]
                + w_xl * w_yh * w_zh * data[..., c_xli, c_yhi, c_zhi]
                + w_xh * w_yl * w_zl * data[..., c_xhi, c_yli, c_zli]
                + w_xh * w_yl * w_zh * data[..., c_xhi, c_yli, c_zhi]
                + w_xh * w_yh * w_zl * data[..., c_xhi, c_yhi, c_zli]
                + w_xh * w_yh * w_zh * data[..., c_xhi, c_yhi, c_zhi]
            )

    else:
        msg = f"Compiled interpolation not implemented for dimension {grid.num_axes}"
        raise NotImplementedError(msg)

    return interpolate_single  # type: ignore