"""Provides protocol types that support mypy type checking of code working with jax arrays.

.. autosummary::
   :nosignatures:

   JaxOperatorType
   JaxDataSetter
   JaxGhostCellSetter
   JaxVirtualPointEvaluator
   JaxInnerStepperType

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from jax import Array
class JaxOperatorType(Protocol):
    """An operator that acts on an array."""

    def __call__(
        self,
        arr: Array,
        args: dict[str, Any] | None = None,
    ) -> Array:
        """Evaluate the operator.

        Args:
            arr: Input array
            args: Additional arguments (optional)

        Returns:
            Output array
        """
class JaxDataSetter(Protocol):
    """Callable that sets the valid data cells and returns the full data array."""

    def __call__(
        self, data_valid: Array, args: dict[str, Any] | None = None
    ) -> Array:
        """Set the valid data cells (and potentially BCs).

        Args:
            data_valid: Valid data array
            args: Additional arguments (optional)

        Returns:
            Full data array including ghost cells
        """
class JaxGhostCellSetter(Protocol):
    """Callable that sets the ghost cells of a full data array."""

    def __call__(
        self, data_full: Array, args: dict[str, Any] | None = None
    ) -> Array:
        """Set the ghost cells.

        Args:
            data_full: Full data array including ghost cells
            args: Additional arguments (optional)

        Returns:
            Data array with the ghost cells set (presumably an updated copy of
            `data_full`; confirm against the implementations)
        """
class JaxVirtualPointEvaluator(Protocol):
    """Callable that evaluates a virtual point of a data array."""

    def __call__(
        self,
        arr: Array,
        idx: tuple[int | slice, ...],
        args: dict[str, Any] | None = None,
    ) -> Array:
        """Evaluate the virtual point at the given position.

        Args:
            arr: Data array
            idx: Index tuple (integers and/or slices) selecting the position
            args: Additional arguments (optional)

        Returns:
            Value(s) of the virtual point at the given position
        """
class JaxInnerStepperType(Protocol):
    """General stepper type working with jax arrays."""

    def __call__(
        self, state_data: Array, t_start: float, t_end: float
    ) -> tuple[Array, float]:
        """General stepper that advances the state given as a jax array.

        Args:
            state_data: The current state
            t_start: Initial time point
            t_end: Desired final time point

        Returns:
            Tuple of the state and time at the final point
        """