r"""This module handles the boundaries of all axes of a grid.
.. autosummary::
:nosignatures:
~BoundariesBase
~BoundariesList
~BoundariesSetter
~set_default_bc
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import itertools
import logging
import warnings
from collections.abc import Iterator, Sequence
from typing import Any, Callable, Union
import numpy as np
from numba.extending import register_jitable
from ... import config
from ...tools.numba import jit
from ...tools.typing import GhostCellSetter
from ..base import GridBase, PeriodicityError
from .axis import BoundaryAxisBase, BoundaryPair, BoundaryPairData, get_boundary_axis
from .local import BCBase, BCDataError, BoundaryData
_logger = logging.getLogger(__name__)
""":class:`logging.Logger`: Logger instance."""
BoundariesData = Union[
BoundaryPairData, Sequence[BoundaryPairData], Callable, "BoundariesBase"
]
BC_LOCAL_KEYS = ["type", "value"] + list(BCBase._conditions.keys())
def _is_local_bc_data(data: dict[str, Any]) -> bool:
"""Tries to identify whether data specifies a local boundary condition."""
return any(key in data for key in BC_LOCAL_KEYS)
[docs]
class BoundariesBase:
"""Base class keeping information about how to set conditions on all boundaries."""
[docs]
def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None:
"""Set the ghost cells for all boundaries.
Args:
data_full (:class:`~numpy.ndarray`):
The full field data including ghost points
set_corners (bool):
Determines whether the corner cells are set using interpolation
args:
Additional arguments that might be supported by special boundary
conditions.
"""
raise NotImplementedError
[docs]
def make_ghost_cell_setter(self) -> GhostCellSetter:
"""Return function that sets the ghost cells on a full array.
Returns:
Callable with signature :code:`(data_full: np.ndarray, args=None)`, which
sets the ghost cells of the full data, potentially using additional
information in `args` (e.g., the time `t` during solving a PDE)
"""
raise NotImplementedError
[docs]
@classmethod
def from_data(cls, data, **kwargs) -> BoundariesBase:
r"""Creates all boundaries from given data.
Args:
data (str or dict or callable):
Data that describes the boundaries. If this is a callable, we create
:class:`~pde.grids.boundaries.axes.BoundariesSetter`. In all other,
cases :class:`~pde.grids.boundaries.axes.BoundariesList` is created and
`data` can either be string denoting a specific boundary condition
applied to all sides or a dictionary with detailed information.
**kwargs:
In some cases additional data can be specified or is even required. For
instance, :class:`~pde.grids.boundaries.axes.BoundariesList` expects
a `grid` (:class:`~pde.grids.base.GridBase`): to which the boundary
condition are associated, and it can use a `rank` (int), which sets
the tensorial rank of the field for this boundary condition.
"""
# check whether this is already of the correct type
if isinstance(data, BoundariesBase):
return data.__class__.from_data(data, **kwargs)
# best guess based on the data:
if callable(data):
return BoundariesSetter.from_data(data)
else:
return BoundariesList.from_data(data, **kwargs)
[docs]
@classmethod
def get_help(cls) -> str:
"""Return information on how boundary conditions can be set."""
return (
'Boundary conditions for each axis are set using a dictionary: {"x": bc_x, '
'"y-": bc_y_lower, "y+": bc_y_upper}. If the associated axis is periodic, '
'the boundary condition needs to be set to "periodic". Otherwise, '
+ BCBase.get_help()
)
[docs]
class BoundariesList(BoundariesBase):
"""Defines boundary conditions for all axes individually."""
def __init__(self, boundaries: list[BoundaryAxisBase]):
"""Initialize with a list of boundaries."""
if len(boundaries) == 0:
raise BCDataError("List of boundaries must not be empty")
# extract grid
self.grid = boundaries[0].grid
# check dimension
if len(boundaries) != self.grid.num_axes:
raise BCDataError(f"Need boundary conditions for {self.grid.num_axes} axes")
# check consistency
for axis, boundary in enumerate(boundaries):
if boundary.grid != self.grid:
raise BCDataError("BoundariesList are not defined on the same grid")
if boundary.axis != axis:
raise BCDataError(
"BoundariesList need to be ordered like the respective axes"
)
if boundary.periodic != self.grid.periodic[axis]:
raise PeriodicityError(
"Periodicity specified in the boundaries conditions is not "
f"compatible with the grid ({boundary.periodic} != "
f"{self.grid.periodic[axis]} for axis {axis})"
)
# create the list of boundaries
self._axes: list[BoundaryAxisBase] = boundaries
@classmethod
def _parse_from_dict(
cls, data: dict[str, Any], *, grid: GridBase, rank: int = 0, **kwargs
) -> list[BoundaryAxisBase]:
"""Creates all boundaries from given data in dictionary format.
Args:
data (dict):
Data that describes the boundaries using a dictionary.
grid (:class:`~pde.grids.base.GridBase`):
The grid with which the boundary condition is associated
rank (int):
The tensorial rank of the field for this boundary condition
"""
if config["boundaries.accept_lists"] and ("low" in data or "high" in data):
# check for legacy format that has been deprecated on 2024-11-23
warnings.warn(
"Deprecated format for boundary conditions." + cls.get_help(),
DeprecationWarning,
)
return [
get_boundary_axis(grid, i, data, rank=rank)
for i in range(grid.num_axes)
]
if _is_local_bc_data(data):
# detected identifier signifying that a single condition was specified
# -> create the same boundary condition for all axes
return [
get_boundary_axis(grid, i, data, rank=rank)
for i in range(grid.num_axes)
]
# else: assume that boundary conditions are given for separate axes
# initialize boundary data with wildcard default
data = data.copy() # we want to modify this dictionary
bc_all = data.pop("*", None)
bc_data = [[bc_all, bc_all] for _ in range(grid.num_axes)]
bc_seen = [[False, False] for _ in range(grid.num_axes)]
# check specific boundary conditions for all axes
for ax, ax_name in enumerate(grid.axes):
# overwrite boundaries whose axes are given
if bc_axes := data.pop(ax_name, None):
bc_data[ax] = [bc_axes, bc_axes]
# overwrite specific conditions for one side
if bc_lower := data.pop(ax_name + "-", None):
bc_data[ax][0] = bc_lower
bc_seen[ax][0] = True
if bc_upper := data.pop(ax_name + "+", None):
bc_data[ax][1] = bc_upper
bc_seen[ax][1] = True
# overwrite conditions for named boundaries
for name, (ax, upper) in grid.boundary_names.items():
if bc := data.pop(name, None):
if bc_seen[ax][upper]:
_logger.warning("Duplicate BC data for axis %s%s", ax, "-+"[upper])
bc_data[ax][upper] = bc
bc_seen[ax][upper] = True
# warn if some keys were left over
if data:
_logger.warning("Didn't use BC data from %s", list(data.keys()))
# find boundary conditions that have not been specified
bcs_unspecified = []
for ax, bc_ax in enumerate(bc_data):
for i, bc_side in enumerate(bc_ax):
if bc_side is None:
bcs_unspecified.append(grid.axes[ax] + "-+"[i])
if bcs_unspecified:
_logger.warning("Didn't specified BCs for %s", bcs_unspecified)
# create the actual boundary conditions
_logger.debug("Parsed BCs as %s", bc_data)
bcs = [
get_boundary_axis(grid, i, tuple(boundary), rank=rank)
for i, boundary in enumerate(bc_data)
]
return bcs
[docs]
@classmethod
def from_data( # type: ignore
cls, data, *, grid: GridBase, rank: int = 0, **kwargs
) -> BoundariesList:
"""Creates all boundaries from given data.
Args:
data (str or dict):
Data that describes the boundaries. This should either be a string
naming a boundary condition or a dictionary with detailed information.
grid (:class:`~pde.grids.base.GridBase`):
The grid with which the boundary condition is associated
rank (int):
The tensorial rank of the field for this boundary condition
"""
# distinguish different possible data formats based on their type
if isinstance(data, BoundariesList):
# boundaries are already in the correct format
if data.grid._mesh is not None:
# we need to exclude this case since otherwise we get into a rabbit hole
# where it is not clear what grid boundary conditions belong to. The
# idea is that users only create boundary conditions for the full grid
# and that the splitting onto subgrids is only done once, automatically,
# and without involving calls to `from_data`
raise ValueError("Cannot create MPI subgrid BC from data")
if data.grid != grid:
raise ValueError(
"The grid of the supplied boundary condition is incompatible with "
f"the current grid ({data.grid!r} != {grid!r})"
)
data.check_value_rank(rank)
return data
elif isinstance(data, BoundariesBase):
# data seems to be given as another base class, which indicates problems
raise TypeError(
"Can only use type `BoundariesList`. Use `BoundariesBase.from_data` "
"for more general data."
)
elif isinstance(data, str):
# a string implies the same boundary condition for all axes
if data.startswith("auto_periodic_"):
# initialize boundary condition that could be periodic
bc = data[len("auto_periodic_") :]
bcs = [
get_boundary_axis(grid, i, "periodic" if per else bc, rank=rank)
for i, per in enumerate(grid.periodic)
]
else:
# assume the same boundary condition for all axes
bcs = [
get_boundary_axis(grid, i, data, rank=rank)
for i in range(grid.num_axes)
]
elif isinstance(data, dict):
# dictionaries can either specify boundary conditions for separate sides or
# they can specify a single boundary condition that is used on all sides
bcs = cls._parse_from_dict(data, grid=grid, rank=rank)
elif config["boundaries.accept_lists"] and hasattr(data, "__len__"):
# sequences have been deprecated on 2024-11-23
warnings.warn(
"Deprecated format for boundary conditions." + cls.get_help(),
DeprecationWarning,
)
if len(data) == grid.num_axes:
# assume that data is given for each boundary
bcs = [
get_boundary_axis(grid, i, boundary, rank=rank)
for i, boundary in enumerate(data)
]
elif grid.num_axes == 1 and len(data) == 2:
# special case where the two sides can be specified directly
bcs = [get_boundary_axis(grid, 0, data, rank=rank)]
else:
raise BCDataError(
f"Got {len(data)} boundary conditions, but grid has "
f"{grid.num_axes} axes." + cls.get_help()
)
else:
# unknown format
raise BCDataError(
f"Unsupported boundary format: `{data}`. " + cls.get_help()
)
return BoundariesList(bcs)
def __str__(self):
items = ", ".join(str(item) for item in self)
return f"[{items}]"
def __len__(self):
return len(self._axes)
def __iter__(self) -> Iterator[BoundaryAxisBase]:
yield from self._axes
def __eq__(self, other):
if not isinstance(other, BoundariesList):
return NotImplemented
return self.grid == other.grid and self._axes == other._axes
def __ne__(self, other):
if not isinstance(other, BoundariesList):
return NotImplemented
return self.grid != other.grid or self._axes != other._axes
@property
def boundaries(self) -> Iterator[BCBase]:
"""Iterator over all non-periodic boundaries."""
for boundary_axis in self._axes: # iterate all axes
if not boundary_axis.periodic: # skip periodic axes
yield from boundary_axis
[docs]
def check_value_rank(self, rank: int) -> None:
"""Check whether the values at the boundaries have the correct rank.
Args:
rank (int):
The tensorial rank of the field for this boundary condition
Throws:
RuntimeError: if any value does not have rank `rank`
"""
for b in self._axes:
b.check_value_rank(rank)
[docs]
def copy(self) -> BoundariesList:
"""Create a copy of the current boundaries."""
return self.__class__([bc.copy() for bc in self._axes])
@property
def periodic(self) -> list[bool]:
""":class:`~numpy.ndarray`: a boolean array indicating which dimensions are
periodic according to the boundary conditions."""
return self.grid.periodic
def __getitem__(self, index):
"""Extract specific boundary conditions.
Args:
index (int or str):
Index can either be a number or an axes name, indicating the axes of
which conditions are returned. Alternatively, `index` can be a named
boundary whose conditions will then be returned
"""
if isinstance(index, str):
# assume that the index is a known identifier
if index in self.grid.boundary_names:
# found a known boundary
axis, upper = self.grid.boundary_names[index]
return self._axes[axis][upper]
# check all axes
for ax, ax_name in enumerate(self.grid.axes):
if index == ax_name:
return self._axes[ax]
if index == ax_name + "-":
return self._axes[ax][False]
if index == ax_name + "+":
return self._axes[ax][True]
# found nothing
raise KeyError(index)
else:
# handle all other cases, in particular integer indices
return self._axes[index]
def __setitem__(self, index, data) -> None:
"""Set specific boundary conditions.
Args:
index (int or str):
Index can either be a number or an axes name, indicating the axes of
which conditions are returned. Alternatively, `index` can be a named
boundary whose conditions will then be returned
data:
Data describing the boundary conditions for this axis or side
"""
if isinstance(index, str):
# assume that the index is a known identifier
if index in self.grid.boundary_names:
# set a specific boundary side
ax, upper = self.grid.boundary_names[index]
self._axes[ax][upper] = data
else:
# check all axes
for ax, ax_name in enumerate(self.grid.axes):
if index == ax_name:
# found just the axis -> set both sides
self._axes[ax] = get_boundary_axis(
grid=self.grid, axis=ax, data=data, rank=self[ax].rank
)
break
if index == ax_name + "-":
# found lower part of the axis
self._axes[ax][False] = data
break
if index == ax_name + "+":
# found upper part of the axis
self._axes[ax][True] = data
break
else:
raise KeyError(index)
else:
# handle all other cases, in particular integer indices
self._axes[index] = data
[docs]
def get_mathematical_representation(self, field_name: str = "C") -> str:
"""Return mathematical representation of the boundary condition."""
result: list[str] = []
for b in self._axes:
try:
result.extend(b.get_mathematical_representation(field_name))
except NotImplementedError:
axis_name = self.grid.axes[b.axis]
result.append(f"Representation not implemented for axis {axis_name}")
return "\n".join(result)
[docs]
def set_ghost_cells(
self, data_full: np.ndarray, *, set_corners: bool = False, args=None
) -> None:
"""Set the ghost cells for all boundaries.
Args:
data_full (:class:`~numpy.ndarray`):
The full field data including ghost points
set_corners (bool):
Determines whether the corner cells are set using interpolation
args:
Additional arguments that might be supported by special boundary
conditions.
"""
for b in self:
b.set_ghost_cells(data_full, args=args)
if set_corners and self.grid.num_axes >= 2:
d = data_full # abbreviation
nxt = [1, -2] # maps 0 to 1 and -1 to -2 to obtain neighboring cells
if self.grid.num_axes == 2:
# iterate all corners
for i, j in itertools.product([0, -1], [0, -1]):
d[..., i, j] = (d[..., nxt[i], j] + d[..., i, nxt[j]]) / 2
elif self.grid.num_axes == 3:
# iterate all edges
for i, j in itertools.product([0, -1], [0, -1]):
d[..., :, i, j] = (+d[..., :, nxt[i], j] + d[..., :, i, nxt[j]]) / 2
d[..., i, :, j] = (+d[..., nxt[i], :, j] + d[..., i, :, nxt[j]]) / 2
d[..., i, j, :] = (+d[..., nxt[i], j, :] + d[..., i, nxt[j], :]) / 2
# iterate all corners
for i, j, k in itertools.product(*[[0, -1]] * 3):
d[..., i, j, k] = (
d[..., nxt[i], j, k]
+ d[..., i, nxt[j], k]
+ d[..., i, j, nxt[k]]
) / 3
elif self.grid.num_axes > 3:
raise NotImplementedError(
f"Can't interpolate corners for grid with {self.grid.num_axes} axes"
)
[docs]
def make_ghost_cell_setter(self) -> GhostCellSetter:
"""Return function that sets the ghost cells on a full array."""
ghost_cell_setters = tuple(b.make_ghost_cell_setter() for b in self)
# TODO: use numba.literal_unroll
# # get the setters for all axes
#
# from pde.tools.numba import jit
#
# @jit
# def set_ghost_cells(data_full: np.ndarray, args=None) -> None:
# for f in nb.literal_unroll(ghost_cell_setters):
# f(data_full, args=args)
#
# return set_ghost_cells
def chain(
fs: Sequence[GhostCellSetter], inner: GhostCellSetter | None = None
) -> GhostCellSetter:
"""Helper function composing setters of all axes recursively."""
first, rest = fs[0], fs[1:]
if inner is None:
@register_jitable
def wrap(data_full: np.ndarray, args=None) -> None:
first(data_full, args=args)
else:
@register_jitable
def wrap(data_full: np.ndarray, args=None) -> None:
inner(data_full, args=args)
first(data_full, args=args)
if rest:
return chain(rest, wrap)
else:
return wrap # type: ignore
return chain(ghost_cell_setters)
[docs]
class BoundariesSetter(BoundariesBase):
"""Represents a function that sets ghost cells to determine boundary conditions.
The function must have accept a :class:`~numpy.ndarray`, which contains the full
field data including the ghost points, and a second, optional argument, which is a
dictionary containing additional parameters, like the current time point `t` in case
of a simulation.
Example:
Here is an example for a simple boundary setter, which sets specific boundary
conditions in the x-direction and periodic conditions in the y-direction of a
grid with two axes. Note that this boundary condition will not work for grids
with other number of axes and no additional checks are performed.
.. code-block:: python
def setter(data, args=None):
data[0, :] = data[1, :] # Vanishing derivative at left side
data[-1, :] = 2 - data[-2, :] # Fixed value `1` at right side
data[:, 0] = data[:, -2] # Periodic BC at top
data[:, -1] = data[:, 1] # Periodic BC at bottom
"""
def __init__(self, setter: GhostCellSetter):
self._setter = setter
[docs]
@classmethod
def from_data(cls, data, **kwargs) -> BoundariesSetter:
"""Creates all boundaries from given data.
Args:
data (callable):
Function that sets the ghost cells
"""
# check whether this is already the correct class
if isinstance(data, BoundariesSetter):
# boundaries are already in the correct format
return data
elif isinstance(data, BoundariesBase):
raise TypeError(
"Can only use type `BoundariesSetter`. Use `BoundariesBase.from_data` "
"for more general data."
)
return BoundariesSetter(data)
[docs]
def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None:
"""Set the ghost cells for all boundaries.
Args:
data_full (:class:`~numpy.ndarray`):
The full field data including ghost points
set_corners (bool):
Determines whether the corner cells are set using interpolation
args:
Additional arguments that might be supported by special boundary
conditions.
"""
self._setter(data_full, args=args)
[docs]
def make_ghost_cell_setter(self) -> GhostCellSetter:
"""Return function that sets the ghost cells on a full array.
Returns:
Callable with signature :code:`(data_full: np.ndarray, args=None)`, which
sets the ghost cells of the full data, potentially using additional
information in `args` (e.g., the time `t` during solving a PDE)
"""
return jit(self._setter) # type: ignore
[docs]
def set_default_bc(
bc_data: BoundariesData | None, default: BoundaryData
) -> BoundariesData:
"""Set a default boundary condition.
Args:
bc_data (str or list or tuple or dict or callable):
User-supplied data specifying boundary conditions
default:
Default condition that should be imposed where user conditions are not given
Returns:
Modified `bc_data` with added defaults
"""
if bc_data is None:
bc_data = default
elif isinstance(bc_data, dict) and not _is_local_bc_data(bc_data):
# set default when boundary conditions for axes are specified
bc_data.setdefault("*", default)
return bc_data