# Source code for pde.solvers.base

"""
Package that contains base classes for solvers.

Besides the abstract base class defining the interfaces, we also provide
:class:`AdaptiveSolverBase`, which contains methods for implementing adaptive solvers.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> 
"""

from __future__ import annotations

import logging
import warnings
from abc import ABCMeta
from inspect import isabstract
from typing import Any, Callable, Dict, List, Optional, Tuple, Type  # @UnusedImport

import numba as nb
import numpy as np
from numba.extending import register_jitable

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.misc import classproperty
from ..tools.numba import is_jitted, jit


class SolverBase(metaclass=ABCMeta):
    """base class for solvers"""

    dt_default: float = 1e-3
    """float: default time step used if no time step was specified"""

    _modify_state_after_step: bool = True
    """bool: flag choosing whether the `modify_after_step` hook of the PDE is called"""

    _subclasses: Dict[str, Type[SolverBase]] = {}
    """dict: dictionary of all inheriting classes"""

    def __init__(self, pde: PDEBase, *, backend: str = "auto"):
        """
        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            backend (str):
                Determines how the function is created. Accepted values are
                'numpy' and 'numba'. Alternatively, 'auto' lets the code decide for
                the most optimal backend.
        """
        self.pde = pde
        self.backend = backend
        self.info: Dict[str, Any] = {"class": self.__class__.__name__}
        if self.pde:
            self.info["pde_class"] = self.pde.__class__.__name__
        self._logger = logging.getLogger(self.__class__.__name__)

    def __init_subclass__(cls, **kwargs):  # @NoSelf
        """register all subclasses to reconstruct them later

        Concrete (non-abstract) subclasses are registered under their class name.
        Subclasses defining a truthy `name` attribute are additionally registered
        under that name, so :meth:`from_name` can find them.
        """
        super().__init_subclass__(**kwargs)
        if not isabstract(cls):
            if cls.__name__ in cls._subclasses:
                warnings.warn(f"Redefining class {cls.__name__}")
            cls._subclasses[cls.__name__] = cls
        if hasattr(cls, "name") and cls.name:
            if cls.name in cls._subclasses:
                logging.warning(f"Solver with name {cls.name} is already registered")
            cls._subclasses[cls.name] = cls

    @classmethod
    def from_name(cls, name: str, pde: PDEBase, **kwargs) -> SolverBase:
        r"""create solver class based on its name

        Solver classes are automatically registered when they inherit from
        :class:`SolverBase`. Note that this also requires that the respective python
        module containing the solver has been loaded before it is attempted to be
        used.

        Args:
            name (str):
                The name of the solver to construct
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            \**kwargs:
                Additional arguments for the constructor of the solver

        Returns:
            An instance of a subclass of :class:`SolverBase`

        Raises:
            ValueError: if no solver is registered under `name`
        """
        try:
            # obtain the solver class associated with `name`
            solver_class = cls._subclasses[name]
        except KeyError:
            # solver was not registered; list only the short names (entries that
            # are plain class names ending in "Solver" are skipped)
            solvers = (
                f"'{solver}'"
                for solver in sorted(cls._subclasses.keys())
                if not solver.endswith("Solver")
            )
            # the KeyError context adds no information for the caller
            raise ValueError(
                f"Unknown solver method '{name}'. Registered solvers are "
                + ", ".join(solvers)
            ) from None

        return solver_class(pde, **kwargs)

    @classproperty
    def registered_solvers(cls) -> List[str]:  # @NoSelf
        """list of str: the names of the registered solvers"""
        return sorted(cls._subclasses)

    @property
    def _compiled(self) -> bool:
        """bool: indicates whether functions need to be compiled"""
        return self.backend == "numba" and not nb.config.DISABLE_JIT

    def _make_modify_after_step(
        self, state: FieldBase
    ) -> Callable[[np.ndarray], float]:
        """create a function that modifies a state after each step

        A noop function will be returned if `_modify_state_after_step` is `False`.

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted.

        Returns:
            Function that is called with the state data and returns a float. The
            steppers accumulate this value in `info["state_modifications"]`.
        """
        if self._modify_state_after_step:
            modify_after_step = jit(self.pde.make_modify_after_step(state))

        else:

            def modify_after_step(state_data: np.ndarray) -> float:
                """noop that leaves the state untouched"""
                return 0.0

        if self._compiled:
            sig_modify = (nb.typeof(state.data),)
            modify_after_step = jit(sig_modify)(modify_after_step)

        return modify_after_step  # type: ignore

    def _store_rhs_backend(self, rhs: Callable) -> None:
        """record in :attr:`info` which backend was used to create `rhs`

        Args:
            rhs (callable):
                The right hand side function returned by the PDE
        """
        if hasattr(rhs, "_backend"):
            self.info["backend"] = rhs._backend  # type: ignore
        elif is_jitted(rhs):
            self.info["backend"] = "numba"
        else:
            self.info["backend"] = "undetermined"

    def _make_pde_rhs(
        self, state: FieldBase, backend: str = "auto"
    ) -> Callable[[np.ndarray, float], np.ndarray]:
        """obtain a function for evaluating the deterministic right hand side

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted.
            backend (str):
                Determines how the function is created. Accepted values are 'numpy'
                and 'numba'. Alternatively, 'auto' lets the code decide for the most
                optimal backend.

        Raises:
            RuntimeError: when the PDE is a stochastic differential equation, i.e.,
                when its `is_sde` attribute is truthy.

        Returns:
            A function that is called with data given by a :class:`~numpy.ndarray`
            and a time. The function returns the deterministic evolution rate.
        """
        # use a default so that PDE-like objects without `is_sde` are supported,
        # consistent with how `make_stepper` queries this attribute
        if getattr(self.pde, "is_sde", False):
            raise RuntimeError(
                "Cannot create a deterministic stepper for a stochastic equation"
            )

        rhs = self.pde.make_pde_rhs(state, backend=backend)  # type: ignore
        self._store_rhs_backend(rhs)
        return rhs

    def _make_sde_rhs(
        self, state: FieldBase, backend: str = "auto"
    ) -> Callable[[np.ndarray, float], Tuple[np.ndarray, np.ndarray]]:
        """obtain a function for evaluating the stochastic right hand side

        Args:
            state (:class:`~pde.fields.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted.
            backend (str):
                Determines how the function is created. Accepted values are 'numpy'
                and 'numba'. Alternatively, 'auto' lets the code decide for the most
                optimal backend.

        Returns:
            A function that is called with data given by a :class:`~numpy.ndarray`
            and a time. The function returns the deterministic evolution rate and a
            realization of the associated noise.
        """
        rhs = self.pde.make_sde_rhs(state, backend=backend)  # type: ignore
        self._store_rhs_backend(rhs)
        return rhs

    def _make_single_step_fixed_dt(
        self, state: FieldBase, dt: float
    ) -> Callable[[np.ndarray, float], None]:
        """return a function doing a single step with a fixed time step

        Subclasses need to implement this method to support fixed time stepping.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function called as `(state_data, t)` that advances `state_data` in place.

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Fixed stepper has not been defined")

    def _make_fixed_stepper(
        self, state: FieldBase, dt: float
    ) -> Callable[[np.ndarray, float, int], Tuple[float, float]]:
        """return a stepper function using an explicit scheme with fixed time steps

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function called as `(state_data, t_start, steps)` that advances
            `state_data` in place by `steps` steps. It returns the time reached and
            the accumulated value reported by the `modify_after_step` hook.
        """
        single_step = self._make_single_step_fixed_dt(state, dt)
        modify_state_after_step = self._modify_state_after_step
        modify_after_step = self._make_modify_after_step(state)

        if self._compiled:
            sig_single_step = (nb.typeof(state.data), nb.double)
            single_step = jit(sig_single_step)(single_step)

        def fixed_stepper(
            state_data: np.ndarray, t_start: float, steps: int
        ) -> Tuple[float, float]:
            """perform `steps` steps with fixed time steps"""
            modifications = 0.0
            # time reached so far; keeps the result well-defined if `steps == 0`
            t_last = t_start
            for i in range(steps):
                # calculate the right hand side
                t = t_start + i * dt
                single_step(state_data, t)
                if modify_state_after_step:
                    modifications += modify_after_step(state_data)
                t_last = t + dt

            return t_last, modifications

        if self._compiled:
            sig_fixed = (nb.typeof(state.data), nb.double, nb.int_)
            fixed_stepper = jit(sig_fixed)(fixed_stepper)

        return fixed_stepper

    def make_stepper(
        self, state: FieldBase, dt: Optional[float] = None
    ) -> Callable[[FieldBase, float, float], float]:
        """return a stepper function using an explicit scheme

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step used (Uses :attr:`SolverBase.dt_default` if `None`)

        Returns:
            Function that can be called to advance the `state` from time `t_start` to
            time `t_end`. The function call signature is
            `(state: FieldBase, t_start: float, t_end: float)`. It modifies the data
            of `state` in place and returns the time that was actually reached.
        """
        # support `None` as a default value, so the controller can signal that
        # the solver should use a default time step
        if dt is None:
            dt = self.dt_default
            self._logger.warning(
                "Explicit stepper with a fixed time step did not receive any "
                f"initial value for `dt`. Using dt={dt}, but specifying a value or "
                "enabling adaptive stepping is advisable."
            )
        dt_float = float(dt)  # explicit casting to help type checking

        self.info["dt"] = dt_float
        self.info["steps"] = 0
        self.info["state_modifications"] = 0.0
        self.info["stochastic"] = getattr(self.pde, "is_sde", False)
        # we don't access self.pde directly since we might want to reuse the solver
        # infrastructure for more general cases where a PDE is not defined

        # create stepper with fixed steps
        fixed_stepper = self._make_fixed_stepper(state, dt_float)

        def wrapped_stepper(state: FieldBase, t_start: float, t_end: float) -> float:
            """advance `state` from `t_start` to `t_end` using fixed steps"""
            # calculate number of steps (which is at least 1); the last step may
            # overshoot `t_end` since `dt` is fixed
            steps = max(1, int(np.ceil((t_end - t_start) / dt_float)))
            t_last, modifications = fixed_stepper(state.data, t_start, steps)
            self.info["steps"] += steps
            self.info["state_modifications"] += modifications
            return t_last

        return wrapped_stepper
class AdaptiveSolverBase(SolverBase):
    """base class for adaptive time steppers"""

    dt_min: float = 1e-10
    """float: minimal time step that the adaptive solver will use"""
    dt_max: float = 1e10
    """float: maximal time step that the adaptive solver will use"""

    def __init__(
        self,
        pde: PDEBase,
        *,
        backend: str = "auto",
        adaptive: bool = True,
        tolerance: float = 1e-4,
    ):
        """
        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The instance describing the pde that needs to be solved
            backend (str):
                Determines how the function is created. Accepted values are 'numpy'
                and 'numba'. Alternatively, 'auto' lets the code decide for the most
                optimal backend.
            adaptive (bool):
                When enabled, the time step is adjusted during the simulation using the
                error tolerance set with `tolerance`.
            tolerance (float):
                The error tolerance used in adaptive time stepping. This is used in
                adaptive time stepping to choose a time step which is small enough so
                the truncation error of a single step is below `tolerance`.
        """
        super().__init__(pde, backend=backend)
        self.adaptive = adaptive
        self.tolerance = tolerance

    def _make_error_synchronizer(self) -> Callable[[float], float]:
        """return helper function that synchronizes errors between multiple processes

        This base implementation returns the error unchanged; it exists as a hook
        for subclasses.
        """

        @register_jitable
        def synchronize_errors(error: float) -> float:
            return error

        return synchronize_errors  # type: ignore

    def _make_dt_adjuster(self) -> Callable[[float, float], float]:
        """return a function that can be used to adjust time steps

        Returns:
            Function called as `(dt, error_rel)` returning the time step for the next
            iteration, clipped to at most :attr:`dt_max`.

        The returned function raises :class:`RuntimeError` if the time step would
        drop below :attr:`dt_min`.
        """
        dt_min = self.dt_min
        dt_min_err = f"Time step below {dt_min}"
        dt_max = self.dt_max

        def adjust_dt(dt: float, error_rel: float) -> float:
            """helper function that adjust the time step

            Args:
                dt (float): Current time step
                error_rel (float): Current (normalized) error estimate

            Returns:
                float: Time step of the next iteration
            """
            # adjust the time step
            if error_rel < 0.00057665:
                # error was very small => maximal increase in dt
                # The constant on the right hand side of the comparison is chosen to
                # agree with the equation for adjusting dt below
                dt *= 4.0
            elif np.isnan(error_rel):
                # state contained NaN => decrease time step strongly
                dt *= 0.25
            else:
                # otherwise, adjust time step according to error
                dt *= max(0.9 * error_rel**-0.2, 0.1)

            # limit time step to permissible bracket
            if dt > dt_max:
                dt = dt_max
            elif dt < dt_min:
                if np.isnan(error_rel):
                    raise RuntimeError("Encountered NaN during simulation")
                else:
                    raise RuntimeError(dt_min_err)

            return dt

        if self._compiled:
            adjust_dt = jit((nb.double, nb.double))(adjust_dt)

        return adjust_dt

    def _make_single_step_variable_dt(
        self, state: FieldBase
    ) -> Callable[[np.ndarray, float, float], np.ndarray]:
        """return a function doing a single step with a variable time step

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted

        Returns:
            Function called as `(state_data, t, dt)` that returns the state data
            after a single (explicit Euler) step of size `dt` starting at time `t`.
            The input `state_data` is not modified.
        """
        rhs_pde = self._make_pde_rhs(state, backend=self.backend)

        def single_step(state_data: np.ndarray, t: float, dt: float) -> np.ndarray:
            """basic implementation of Euler scheme"""
            return state_data + dt * rhs_pde(state_data, t)  # type: ignore

        return single_step

    def _make_single_step_error_estimate(
        self, state: FieldBase
    ) -> Callable[[np.ndarray, float, float], Tuple[np.ndarray, float]]:
        """make a stepper that also estimates the error

        The error is estimated by comparing one step of size `dt` with two steps of
        size `dt / 2`.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted

        Returns:
            Function called as `(state_data, t, dt)` returning the state after the
            two half steps together with the maximal absolute deviation from the
            single full step.

        Raises:
            RuntimeError: when the PDE is a stochastic differential equation
        """
        # use a default so PDE-like objects without `is_sde` are supported
        if getattr(self.pde, "is_sde", False):
            raise RuntimeError("Cannot use adaptive stepper with stochastic equation")

        single_step = self._make_single_step_variable_dt(state)
        if compiled := self._compiled:
            sig_single_step = (nb.typeof(state.data), nb.double, nb.double)
            single_step = jit(sig_single_step)(single_step)

        def single_step_error_estimate(
            state_data: np.ndarray, t: float, dt: float
        ) -> Tuple[np.ndarray, float]:
            """basic stepper to estimate error"""
            # single step with dt
            k1 = single_step(state_data, t, dt)

            # double step with half the time step
            k2a = single_step(state_data, t, 0.5 * dt)
            k2 = single_step(k2a, t + 0.5 * dt, 0.5 * dt)

            # calculate maximal error
            if compiled:
                error = 0.0
                for i in range(state_data.size):
                    # max() has the weird behavior that `max(np.nan, 0)` is `np.nan`
                    # while `max(0, np.nan) == 0`. To propagate NaNs in the
                    # evaluation, we thus need to use the following order:
                    error = max(abs(k1.flat[i] - k2.flat[i]), error)
            else:
                error = np.abs(k1 - k2).max()

            return k2, error

        return single_step_error_estimate

    def _make_adaptive_stepper(
        self, state: FieldBase
    ) -> Callable[
        [np.ndarray, float, float, float, Optional[OnlineStatistics]],
        Tuple[float, float, int, float],
    ]:
        """make an adaptive Euler stepper

        Note:
            When compiling, this reads `self.info["dt_statistics"]` to determine the
            type of the statistics argument, so it must be set beforehand (as done
            by :meth:`make_stepper`).

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted

        Returns:
            Function that can be called to advance the state data from time
            `t_start` to time `t_end`. The function call signature is
            `(state_data: numpy.ndarray, t_start: float, t_end: float,
            dt_init: float, dt_stats: OnlineStatistics | None)`. It modifies
            `state_data` in place and returns the tuple `(t, dt_opt, steps,
            modifications)`: the time reached, the suggested time step for the
            next call, the number of accepted steps, and the accumulated value
            reported by the `modify_after_step` hook.
        """
        # obtain functions determining how the PDE is evolved
        single_step_error = self._make_single_step_error_estimate(state)
        modify_after_step = self._make_modify_after_step(state)
        modify_state_after_step = self._modify_state_after_step
        sync_errors = self._make_error_synchronizer()

        # obtain auxiliary functions
        adjust_dt = self._make_dt_adjuster()
        tolerance = self.tolerance
        dt_min = self.dt_min

        if self._compiled:
            # compile paired stepper
            sig_stepper = (nb.typeof(state.data), nb.double, nb.double)
            single_step_error = jit(sig_stepper)(single_step_error)

        def adaptive_stepper(
            state_data: np.ndarray,
            t_start: float,
            t_end: float,
            dt_init: float,
            dt_stats: Optional[OnlineStatistics] = None,
        ) -> Tuple[float, float, int, float]:
            """adaptive stepper that advances the state in time"""
            modifications = 0.0
            dt_opt = dt_init
            t = t_start
            steps = 0
            while True:
                # use a smaller (but not too small) time step if close to t_end
                dt_step = max(min(dt_opt, t_end - t), dt_min)

                # two different steppings to estimate errors
                new_state, error = single_step_error(state_data, t, dt_step)

                error_rel = error / tolerance  # normalize error to given tolerance
                # synchronize the error between all processes (if necessary)
                error_rel = sync_errors(error_rel)

                # do the step if the error is sufficiently small
                if error_rel <= 1:
                    steps += 1
                    t += dt_step
                    state_data[...] = new_state
                    if modify_state_after_step:
                        modifications += modify_after_step(state_data)
                    if dt_stats is not None:
                        dt_stats.add(dt_step)

                if t < t_end:
                    # adjust the time step and continue (this also shrinks the time
                    # step after a rejected step)
                    dt_opt = adjust_dt(dt_step, error_rel)
                else:
                    # keep the last regular `dt_opt` instead of the possibly
                    # truncated final step
                    break  # return to the controller

            return t, dt_opt, steps, modifications

        if self._compiled:
            # compile inner stepper
            sig_adaptive = (
                nb.typeof(state.data),
                nb.double,
                nb.double,
                nb.double,
                nb.typeof(self.info["dt_statistics"]),
            )
            adaptive_stepper = jit(sig_adaptive)(adaptive_stepper)

        self._logger.info("Initialized adaptive stepper")
        return adaptive_stepper

    def make_stepper(
        self, state: FieldBase, dt: Optional[float] = None
    ) -> Callable[[FieldBase, float, float], float]:
        """return a stepper function using an explicit scheme

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step used (Uses :attr:`SolverBase.dt_default` if `None`). This sets
                the initial time step for adaptive solvers.

        Returns:
            Function that can be called to advance the `state` from time `t_start` to
            time `t_end`. The function call signature is
            `(state: FieldBase, t_start: float, t_end: float)`. It modifies the data
            of `state` in place and returns the time that was actually reached.
        """
        if not self.adaptive:
            # create stepper with fixed steps
            return super().make_stepper(state, dt)

        # support `None` as a default value, so the controller can signal that
        # the solver should use a default time step
        if dt is None:
            dt_float = self.dt_default
        else:
            dt_float = float(dt)  # explicit casting to help type checking

        self.info["dt"] = dt_float
        self.info["dt_adaptive"] = True
        self.info["steps"] = 0
        self.info["stochastic"] = getattr(self.pde, "is_sde", False)
        self.info["state_modifications"] = 0.0

        # create stepper with adaptive steps; the statistics object must exist
        # before the stepper is created since its type is used for compilation
        self.info["dt_statistics"] = OnlineStatistics()
        adaptive_stepper = self._make_adaptive_stepper(state)

        def wrapped_stepper(state: FieldBase, t_start: float, t_end: float) -> float:
            """advance `state` from `t_start` to `t_end` using adaptive steps"""
            nonlocal dt_float  # `dt_float` stores value for the next call

            t_last, dt_float, steps, modifications = adaptive_stepper(
                state.data, t_start, t_end, dt_float, self.info["dt_statistics"]
            )
            self.info["steps"] += steps
            self.info["state_modifications"] += modifications
            return t_last

        return wrapped_stepper