# Source code for pde.backends.torch.utils

"""Defines utilities for the torch backend.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from collections.abc import Callable

import numpy as np
import torch
from numpy.typing import DTypeLike

AnyDType = DTypeLike | torch.dtype

NUMPY_TO_TORCH_DTYPE: dict[DTypeLike, torch.dtype] = {
    np.bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.double: torch.double,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
# also define inverse mapping to proper numpy dtypes
TORCH_TO_NUMPY_DTYPE = {v: np.dtype(k) for k, v in NUMPY_TO_TORCH_DTYPE.items()}
# add the proper numpy dtype as an alternative
NUMPY_TO_TORCH_DTYPE |= {np.dtype(k): v for k, v in NUMPY_TO_TORCH_DTYPE.items()}


[docs] class TorchOperatorBase(torch.nn.Module): """Base class for operators implemented in torch.""" def __init__(self, *, dtype: DTypeLike): """Initialize the torch operator. Args: dtype: The data type of the field using the numpy convention """ super().__init__() self.dtype = np.dtype(dtype)
[docs] def register_array(self, name: str, arr: np.ndarray | torch.Tensor) -> None: """Register an array as a buffer in the torch module. Args: name (str): The name under which the buffer is registered arr (:class:`numpy.ndarray` or :class:`torch.Tensor`): The array to register. If a numpy array is provided, it will be converted to a torch tensor with the appropriate dtype. """ if isinstance(arr, np.ndarray): tensor = torch.from_numpy(np.asarray(arr, dtype=self.dtype)) elif isinstance(arr, torch.Tensor): tensor = arr else: raise TypeError self.register_buffer(name, tensor)
[docs] def torch_heaviside(x1: torch.Tensor, x2: torch.Tensor | None = None) -> torch.Tensor: """Return the Heaviside step function using torch. This wraps :func:`torch.heaviside` and ensures that scalar fallback values are converted to tensors with a dtype compatible with `x1`. Args: x1 (:class:`torch.Tensor`): Input values at which the Heaviside function is evaluated. x2 (:class:`torch.Tensor`, optional): Value used where `x1 == 0`. If omitted, `0.5` is used. Returns: :class:`torch.Tensor`: Tensor containing the Heaviside values of `x1`. """ x1_t = torch.tensor(x1) if x2 is None: return torch.heaviside(x1_t, torch.tensor(0.5, dtype=x1_t.dtype)) return torch.heaviside(x1_t, torch.tensor(x2, dtype=x1_t.dtype))
def torch_hypot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Return the Euclidean norm ``sqrt(x1**2 + x2**2)`` using torch.

    This wraps :func:`torch.hypot` and ensures that both inputs are converted to
    tensors before evaluation.

    Args:
        x1 (:class:`torch.Tensor`):
            First input values.
        x2 (:class:`torch.Tensor`):
            Second input values.

    Returns:
        :class:`torch.Tensor`: Tensor containing the element-wise hypotenuse of
        `x1` and `x2`.
    """
    # `torch.as_tensor` passes existing tensors through unchanged, so no data is
    # copied, no UserWarning is raised, and gradients can flow through the inputs
    return torch.hypot(torch.as_tensor(x1), torch.as_tensor(x2))
# lookup table mapping the name of a special function to its torch implementation
SPECIAL_FUNCTIONS_TORCH: dict[str, Callable] = {
    "Heaviside": torch_heaviside,
    "hypot": torch_hypot,
}