Source code for pde.backends.torch.utils

"""Defines utilities for the torch backend.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from typing import Any, Protocol

import numpy as np
import torch
from numpy.typing import DTypeLike
from torch import Tensor

AnyDType = DTypeLike | torch.dtype

NUMPY_TO_TORCH_DTYPE: dict[DTypeLike, torch.dtype] = {
    np.bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
NUMPY_TO_TORCH_DTYPE = {np.dtype(k): v for k, v in NUMPY_TO_TORCH_DTYPE.items()}


def get_torch_dtype(dtype: AnyDType) -> torch.dtype:
    """Convert dtype to torch dtype.

    Args:
        dtype:
            dtype which could be a python type, a numpy dtype (or anything else
            :func:`numpy.dtype` understands), or already a torch dtype

    Returns:
        :class:`torch.dtype`: A proper dtype for torch

    Raises:
        TypeError: If numpy cannot interpret `dtype`
        KeyError: If the dtype has no torch equivalent in the supported mapping
    """
    if isinstance(dtype, torch.dtype):
        return dtype  # already a torch dtype; nothing to convert

    # normalize every numpy-compatible description to a canonical np.dtype,
    # which is how the lookup table is keyed
    np_dtype = np.dtype(dtype)
    try:
        return NUMPY_TO_TORCH_DTYPE[np_dtype]
    except KeyError:
        # keep the exception type callers may rely on, but explain the failure
        supported = ", ".join(str(d) for d in NUMPY_TO_TORCH_DTYPE)
        raise KeyError(
            f"dtype `{np_dtype}` is not supported by the torch backend "
            f"(supported: {supported})"
        ) from None


[docs] class TorchOperatorType(Protocol): """An operator that acts on an array.""" def __call__( self, arr: Tensor, out: Tensor | None = None, args: dict[str, Any] | None = None, ) -> Tensor: """Evaluate the operator."""