Source code for pde.backends.torch.operators.common

"""This module implements infrastructure for differential operators using torch.

.. autosummary::
   :nosignatures:

   TorchOperator
   IntegralOperator

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch

from ....grids.boundaries import BoundariesList
from .._boundaries import make_local_ghost_cell_setter

if TYPE_CHECKING:
    from torch import Tensor

    from ....grids import GridBase


class TorchOperator(torch.nn.Module):
    """Base class for operators implemented in torch.

    The operator owns a registered buffer `data_full` that has one additional
    cell on either side of each grid axis (the ghost cells). Valid data is copied
    into the interior of this buffer and the ghost cells are then filled by the
    ghost cell setters derived from the boundary conditions.
    """

    data_full: Tensor

    rank_in: int = 0
    """int: The rank of the input tensor"""

    def __init__(
        self,
        grid: GridBase,
        bcs: BoundariesList | None,
        *,
        dtype=np.double,
    ):
        """Initialize the torch operator.

        Args:
            grid (:class:`~pde.grids.base.GridBase`):
                The grid on which the operator acts
            bcs (:class:`~pde.grids.boundaries.axes.BoundariesList` or None):
                The boundary conditions applied to the field. If `None`, no boundary
                conditions are enforced and it is assumed that the operator is
                applied to the full field.
            dtype:
                The data type of the field. Both numpy-style dtypes (e.g.,
                `np.double`) and :class:`torch.dtype` instances are accepted.

        Raises:
            ValueError: If `bcs` was defined for a different grid than `grid`
            NotImplementedError: If `bcs` is neither `None` nor a `BoundariesList`
        """
        super().__init__()

        # initialize buffer for full data (including ghost cells)
        self.grid = grid
        full_shape = (grid.dim,) * self.rank_in + tuple(n + 2 for n in grid.shape)
        # torch factory functions need a torch dtype; numpy dtypes are converted
        data_full = torch.empty(full_shape, dtype=self._as_torch_dtype(dtype))
        self.register_buffer("data_full", data_full)

        # always define the setters, so `set_ghost_cells` is a no-op without BCs
        self.ghost_cell_setters = []
        if bcs is None:
            self.apply_bcs = False
        elif isinstance(bcs, BoundariesList):
            # get the ghost cell setters for all boundaries
            if grid != bcs.grid:
                msg = "Different grids for operator and BCs"
                raise ValueError(msg)
            self.apply_bcs = True
            self.ghost_cell_setters = [
                make_local_ghost_cell_setter(bc_local)
                for bc_axis in bcs
                for bc_local in bc_axis
            ]
        else:
            msg = f"Unsupported boundary conditions of type `{type(bcs).__name__}`"
            raise NotImplementedError(msg)

    @staticmethod
    def _as_torch_dtype(dtype) -> torch.dtype:
        """Convert a dtype specification to the corresponding torch dtype.

        Args:
            dtype:
                Either a :class:`torch.dtype`, which is returned unchanged, or
                anything that numpy accepts as a dtype.

        Returns:
            :class:`torch.dtype`: The equivalent torch data type
        """
        if isinstance(dtype, torch.dtype):
            return dtype
        # let torch determine the dtype matching an (empty) numpy array
        return torch.from_numpy(np.empty(0, dtype=dtype)).dtype

    def set_valid(self, arr: Tensor) -> None:
        """Set valid data in the internal full array.

        Args:
            arr (:class:`torch.Tensor`):
                The data of the valid grid points

        Raises:
            NotImplementedError: If the grid does not have 1, 2, or 3 axes
        """
        # the outermost cell along each grid axis is a ghost cell and is skipped
        if self.grid.num_axes == 1:
            self.data_full[..., 1:-1] = arr
        elif self.grid.num_axes == 2:
            self.data_full[..., 1:-1, 1:-1] = arr
        elif self.grid.num_axes == 3:
            self.data_full[..., 1:-1, 1:-1, 1:-1] = arr
        else:
            raise NotImplementedError

    def set_ghost_cells(self, args=None):
        """Set the ghost cells of the internal full array.

        All ghost cell setters are called in order on the internal buffer. Nothing
        happens if the operator was created without boundary conditions.

        Args:
            args:
                Additional information passed on to each ghost cell setter (e.g.,
                the time `t` during solving a PDE)
        """
        for set_ghost_cells in self.ghost_cell_setters:
            set_ghost_cells(self.data_full, args=args)

    def get_full_data(self, arr: Tensor, args=None) -> Tensor:
        """Get full data array including ghost cells.

        Args:
            arr (:class:`torch.Tensor`):
                The input data. If boundary conditions are applied, this should
                contain only the valid grid points. Otherwise, it should already
                include ghost cells.
            args:
                Additional information passed on to the ghost cell setters

        Returns:
            :class:`torch.Tensor`: The full data array including ghost cells with
            boundary conditions applied if necessary. With boundary conditions this
            is the internal buffer, which is overwritten by the next call.
        """
        if self.apply_bcs:
            # `arr` only contains the valid data and we need to apply boundary
            # conditions. We thus use the internal data `self.data_full`
            self.set_valid(arr)
            self.set_ghost_cells(args=args)
            return self.data_full
        # Assume `arr` already contains the full data
        return arr
class IntegralOperator(torch.nn.Module):
    """Operator integrating a field implemented in torch."""

    def __init__(self, grid: GridBase, *, dtype=np.double):
        """Initialize the torch operator.

        Args:
            grid (:class:`~pde.grids.base.GridBase`):
                The grid on which the operator acts
            dtype:
                The data type of the field. Both numpy-style dtypes and
                :class:`torch.dtype` instances are accepted. The cell volumes are
                stored with this data type if it is a floating point or complex
                type, so integrating a field of that type does not change its
                data type.
        """
        super().__init__()

        # initialize cell volumes array necessary for integration
        self.grid = grid
        # the spatial axes are the last `num_axes` dimensions of the data
        self.spatial_dims = tuple(range(-grid.num_axes, 0, 1))
        cell_volumes = np.broadcast_to(grid.cell_volumes, grid.shape)
        # the broadcast view is read-only, so copy it before handing it to torch
        volumes = torch.from_numpy(cell_volumes.copy())

        # determine the torch dtype corresponding to the requested field dtype
        if isinstance(dtype, torch.dtype):
            torch_dtype = dtype
        else:
            torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
        # integer types are skipped since casting would truncate the cell volumes
        if torch_dtype.is_floating_point or torch_dtype.is_complex:
            volumes = volumes.to(torch_dtype)
        self.register_buffer("cell_volumes", volumes)

    def forward(self, arr: Tensor) -> Tensor:
        """Integrate data over the spatial dimensions of the grid.

        Args:
            arr (:class:`torch.Tensor`):
                The data to integrate. The trailing dimensions must be compatible
                with the shape of the grid.

        Returns:
            :class:`torch.Tensor`: The sum of the data weighted by the cell volumes,
            where the trailing spatial dimensions have been summed over.
        """
        amounts = arr * self.cell_volumes  # type: ignore
        return torch.sum(amounts, dim=self.spatial_dims)