"""Provides support for mypy type checking of the package.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, TypeVar
import numpy as np
from numpy.typing import ArrayLike # noqa: F401
if TYPE_CHECKING:
from ..fields import DataFieldBase, FieldCollection
from ..grids.base import GridBase
# types for single numbers:
Real = int | float # a real number (no complex number allowed)
Number = Real | complex | np.number # any number, including complex numbers
# array types:
NumericArray = np.ndarray[Any, np.dtype[np.number]] # array of numbers (incl complex)
NumberOrArray = Number | NumericArray # number or array of numbers (incl complex)
# a floating number or an array of floating (no integers and no complex numbers)
FloatingArray = np.ndarray[Any, np.dtype[np.floating]]
FloatOrArray = float | np.ndarray[Any, np.dtype[np.floating]]
# miscellaneous types:
BackendType = Literal["scipy", "numpy", "numba", "numba_mpi"]
TField = TypeVar("TField", "FieldCollection", "DataFieldBase", covariant=True)
[docs]
class OperatorInfo(NamedTuple):
    """Immutable record describing a single operator.

    Being a named tuple, instances can be unpacked positionally in the order
    ``(factory, rank_in, rank_out, name)``; only `name` has a default value.
    """

    # callable creating the operator; typed by the `OperatorFactory` protocol
    factory: OperatorFactory
    # NOTE(review): presumably the rank of the operator's input -- confirm
    rank_in: int
    # NOTE(review): presumably the rank of the operator's output -- confirm
    rank_out: int
    # optional unique label that helps caching; empty string when not given
    name: str = ""
[docs]
class OperatorImplType(Protocol):
    """Call signature of an operator implementation acting on an array.

    Any callable that takes an input array and an output array satisfies this
    protocol. The return annotation is ``None``, so the result is presumably
    delivered through `out` rather than returned.
    """

    def __call__(self, arr: NumericArray, out: NumericArray) -> None:
        """Apply the operator to `arr`; nothing is returned."""
        ...
[docs]
class OperatorFactory(Protocol):
    """Call signature of a factory that builds an operator for a given grid.

    The factory receives the grid plus arbitrary keyword arguments and hands
    back a callable following the `OperatorImplType` protocol.
    """

    def __call__(self, grid: GridBase, **kwargs) -> OperatorImplType:
        """Build and return the operator for `grid`."""
        ...
[docs]
class OperatorType(Protocol):
    """Call signature of an operator that acts on an array and returns one.

    Unlike `OperatorImplType`, the output array is optional here and the result
    is returned to the caller. An optional dictionary of additional arguments
    may be supplied as well.
    """

    def __call__(
        self,
        arr: NumericArray,
        out: NumericArray | None = None,
        args: dict[str, Any] | None = None,
    ) -> NumericArray:
        """Apply the operator to `arr` and return the resulting array."""
        ...
[docs]
class CellVolume(Protocol):
    """Call signature of a function returning the volume of a single cell."""

    def __call__(self, *args: int) -> float:
        """Return the volume of the cell at the position given by `args`.

        The position is passed as a variable number of integer arguments.
        """
        ...
[docs]
class VirtualPointEvaluator(Protocol):
    """Call signature of a function evaluating a virtual point of an array."""

    # NOTE(review): `args` is left unannotated as in the original; it may be
    # the same optional dictionary that `OperatorType` accepts -- confirm
    # before adding an annotation
    def __call__(self, arr: NumericArray, idx: tuple[int, ...], args=None) -> float:
        """Return the value of the virtual point at the position `idx`."""
        ...
[docs]
class GhostCellSetter(Protocol):
    """Call signature of a function that sets the ghost cells of an array.

    The return annotation is ``None``, so `data_full` is presumably modified in
    place.
    """

    def __call__(self, data_full: NumericArray, args=None) -> None:
        """Set the ghost cells of `data_full`; nothing is returned."""
        ...
[docs]
class DataSetter(Protocol):
    """Call signature of a function that sets the valid data cells of an array.

    Implementations may additionally set boundary conditions. The return
    annotation is ``None``, so `data_full` is presumably modified in place.
    """

    def __call__(self, data_full: NumericArray, args=None) -> None:
        """Set the valid data cells (and potentially BCs) of `data_full`."""
        ...
[docs]
class StepperHook(Protocol):
    """Call signature of a hook that inspects the current state.

    The hook may analyze the state and potentially modify it; it does not
    return anything.
    """

    def __call__(
        self,
        state_data: NumericArray,
        t: float,
        post_step_data: NumericArray,
    ) -> None:
        """Analyze, and potentially modify, the current state."""
        ...