# Source code for pde.backends.torch.typing

"""Defines types specific to the torch backend.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from typing import Protocol

import numpy as np
import torch
from numpy.typing import DTypeLike

AnyDType = DTypeLike | torch.dtype

NUMPY_TO_TORCH_DTYPE: dict[DTypeLike, torch.dtype] = {
    np.bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.double: torch.double,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
# also define inverse mapping to proper numpy dtypes
TORCH_TO_NUMPY_DTYPE = {v: np.dtype(k) for k, v in NUMPY_TO_TORCH_DTYPE.items()}
# add the proper numpy dtype as an alternative
NUMPY_TO_TORCH_DTYPE |= {np.dtype(k): v for k, v in NUMPY_TO_TORCH_DTYPE.items()}


class TorchRHSType(Protocol):
    """General right-hand-side function type working with torch tensors."""

    def __call__(self, state_data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Evaluates right hand side of the PDE in a torch backend.

        Args:
            state_data (:class:`~torch.Tensor`):
                The current state
            t (:class:`~torch.Tensor`):
                Current time point

        Returns:
            :class:`~torch.Tensor`: Evolution rate
        """


class TorchInnerStepperType(Protocol):
    """General backend-level stepping-function type working with torch tensors."""

    def __call__(
        self, state_data: torch.Tensor, t_start: float, t_end: float
    ) -> tuple[torch.Tensor, float]:
        """Advance the state given as a torch tensor.

        Args:
            state_data (:class:`~torch.Tensor`):
                The current state
            t_start (float):
                Initial time point
            t_end (float):
                Desired final time point

        Returns:
            tuple: The state at the final point and the time of that final point
        """