Source code for pde.solvers.explicit

"""Defines an explicit solver supporting various methods.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np

from .base import AdaptiveSolverBase, AdaptiveStepperType, _make_dt_adjuster

if TYPE_CHECKING:
    from collections.abc import Callable

    from ..pdes.base import PDEBase
    from ..tools.math import OnlineStatistics
    from ..tools.typing import NumericArray, StepperHook, TField


class EulerSolver(AdaptiveSolverBase):
    """Explicit Euler solver.

    Deterministic equations are advanced with the forward Euler scheme, using either
    a fixed or an adaptive time step. Stochastic equations (`pde.is_sde`) are
    advanced with the Euler-Maruyama scheme, which is only available with a fixed
    time step.
    """

    name = "euler"

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], None]:
        """Make a simple Euler stepper with fixed time step.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function that advances the state data in place by a single step of size
            `dt`. The function call signature is
            `(state_data: numpy.ndarray, t: float) -> None`, where `t` is the time
            at the beginning of the step.
        """
        if self.pde.is_sde:
            # handle stochastic version of the pde
            self.info["scheme"] = "euler-maruyama"
            rhs_pde = self._make_pde_rhs(state)
            rhs_noise = self.pde.make_noise_realization(state, backend=self.backend)  # type: ignore

            def stepper(state_data: NumericArray, t: float) -> None:
                """Perform a single Euler-Maruyama step."""
                # evaluate both contributions on the state at the start of the step
                evolution_rate = rhs_pde(state_data, t)
                noise_realization = rhs_noise(state_data, t)
                state_data += dt * evolution_rate
                if noise_realization is not None:
                    # the noise contribution scales with the square root of dt
                    state_data += np.sqrt(dt) * noise_realization

            self._logger.info("Init explicit Euler-Maruyama stepper with dt=%g", dt)

        else:
            # handle deterministic version of the pde
            self.info["scheme"] = "euler"
            rhs_pde = self._make_pde_rhs(state)

            def stepper(state_data: NumericArray, t: float) -> None:
                """Perform a single Euler step."""
                state_data += dt * rhs_pde(state_data, t)

            self._logger.info("Init explicit Euler stepper with dt=%g", dt)

        return stepper

    def _make_adaptive_stepper(
        self,
        state: TField,
        *,
        post_step_hook: StepperHook | None = None,
        adjust_dt: Callable[[float, float], float] | None = None,
    ) -> AdaptiveStepperType:
        """Make an adaptive Euler stepper.

        The local error of a full Euler step is estimated by comparing it to two
        Euler steps of half the size. A step is accepted when this error, divided by
        `self.tolerance` and synchronized across MPI processes, is at most 1.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted
            post_step_hook (callable or None):
                A function that runs the post_step_hook
            adjust_dt (callable or None):
                A function that is used to adjust the time step. The function takes
                the current time step and a relative error and returns an adjusted
                time step

        Returns:
            Function that advances the state data in place from time `t_start` until
            time `t_end` is reached (a final step of at least `dt_min` may overshoot
            it). The function call signature is `(state_data: numpy.ndarray,
            t_start: float, t_end: float, dt_init: float, dt_stats=None,
            post_step_data=None)` and it returns the tuple `(t, dt_opt, steps)` with
            the time reached, the suggested next time step, and the number of
            accepted steps.

        Raises:
            RuntimeError: if the PDE is stochastic
        """
        # General comment: We implement the full adaptive scheme here instead of just
        # defining `_make_single_step_error_estimate` to do some optimizations. In
        # particular, we reuse the calculated right hand side in cases where the step
        # was not successful.
        if self.pde.is_sde:
            msg = "Deterministic adaptive stepper does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs_pde = self._make_pde_rhs(state)
        if post_step_hook is None:
            post_step_hook = self._make_post_step_hook(state)

        # obtain auxiliary functions
        sync_errors = self._backend_obj.make_mpi_synchronizer(operator="MAX")
        if adjust_dt is None:
            adjust_dt = _make_dt_adjuster(self.dt_min, self.dt_max)
        tolerance = self.tolerance
        dt_min = self.dt_min

        def adaptive_stepper(
            state_data: NumericArray,
            t_start: float,
            t_end: float,
            dt_init: float,
            dt_stats: OnlineStatistics | None = None,
            post_step_data=None,
        ) -> tuple[float, float, int]:
            """Adaptive stepper that advances the state in time."""
            dt_opt = dt_init
            rate = rhs_pde(state_data, t_start)  # calculate initial rate

            steps = 0
            t = t_start
            while True:
                # use a smaller (but not too small) time step if close to t_end
                dt_step = max(min(dt_opt, t_end - t), dt_min)

                # do single step with dt
                step_large = state_data + dt_step * rate

                # do double step with half the time step
                step_small = state_data + 0.5 * dt_step * rate

                try:
                    # calculate rate at the midpoint of the double step
                    rate_midpoint = rhs_pde(step_small, t + 0.5 * dt_step)
                except Exception:
                    # an exception likely signals that rate could not be calculated
                    error_rel = np.nan
                else:
                    # advance to end of double step
                    step_small += 0.5 * dt_step * rate_midpoint

                    # calculate maximal error
                    error = np.abs(step_large - step_small).max()
                    error_rel = error / tolerance  # normalize error to given tolerance

                # synchronize the error between all processes (necessary for MPI)
                error_rel = sync_errors(error_rel)

                if error_rel <= 1:  # error is sufficiently small
                    try:
                        # calculate the rate at the putative new state; it is reused
                        # as the starting rate of the next step, so it must be
                        # evaluated at the time the new state belongs to
                        rate = rhs_pde(step_small, t + dt_step)
                    except Exception:
                        # calculating the rate failed => retry with smaller dt
                        error_rel = np.nan
                    else:
                        # everything worked => do the step
                        steps += 1
                        t += dt_step
                        state_data[...] = step_small
                        post_step_hook(state_data, t, post_step_data)
                        if dt_stats is not None:
                            dt_stats.add(dt_step)

                if t < t_end:
                    # adjust the time step and continue
                    dt_opt = adjust_dt(dt_step, error_rel)
                else:
                    break  # return to the controller

            return t, dt_opt, steps

        self._logger.info("Init adaptive Euler stepper")
        return adaptive_stepper
class RungeKuttaSolver(AdaptiveSolverBase):
    """Explicit Runge-Kutta PDE solver.

    With a fixed time step, the classical fourth-order Runge-Kutta scheme is used.
    With adaptive time stepping, the embedded Runge-Kutta-Fehlberg pair of orders 4
    and 5 is used: the state is advanced with the fourth-order solution and the
    difference to the fifth-order solution serves as the error estimate.
    """

    name = "runge-kutta"

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], None]:
        """Make function doing a single classical (fourth-order) Runge-Kutta step.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function that advances the state data in place by a single step of size
            `dt`. The function call signature is
            `(state_data: numpy.ndarray, t: float) -> None`, where `t` is the time
            at the beginning of the step.

        Raises:
            RuntimeError: if the PDE is stochastic
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self._make_pde_rhs(state)

        def stepper(state_data: NumericArray, t: float) -> None:
            """Perform a single classical Runge-Kutta step."""
            # calculate the intermediate values in Runge-Kutta
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + 0.5 * k1, t + 0.5 * dt)
            k3 = dt * rhs(state_data + 0.5 * k2, t + 0.5 * dt)
            k4 = dt * rhs(state_data + k3, t + dt)
            state_data += (k1 + 2 * k2 + 2 * k3 + k4) / 6

        # NOTE: the log text says "45" although this fixed-step scheme is RK4
        self._logger.info("Init explicit Runge-Kutta-45 stepper with dt=%g", dt)
        return stepper

    def _make_single_step_error_estimate(
        self, state: TField
    ) -> Callable[[NumericArray, float, float], tuple[NumericArray, float]]:
        """Make a stepper with error estimate using the Runge-Kutta-Fehlberg method.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted

        Returns:
            Function with call signature
            `(state_data: numpy.ndarray, t: float, dt: float)` that does not modify
            `state_data` but returns the tuple `(state_new, error)`: the state after
            one fourth-order step of size `dt` and the maximal absolute difference
            between the fourth- and fifth-order solutions.

        Raises:
            RuntimeError: if the PDE is stochastic
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self._make_pde_rhs(state)

        # use Runge-Kutta-Fehlberg method
        # define coefficients for RK4(5), formula 2 Table III in Fehlberg
        # nodes (fractions of dt at which the stages are evaluated)
        a2 = 1 / 4
        a3 = 3 / 8
        a4 = 12 / 13
        a5 = 1.0
        a6 = 1 / 2

        # stage coupling coefficients
        b21 = 1 / 4
        b31 = 3 / 32
        b32 = 9 / 32
        b41 = 1932 / 2197
        b42 = -7200 / 2197
        b43 = 7296 / 2197
        b51 = 439 / 216
        b52 = -8.0
        b53 = 3680 / 513
        b54 = -845 / 4104
        b61 = -8 / 27
        b62 = 2.0
        b63 = -3544 / 2565
        b64 = 1859 / 4104
        b65 = -11 / 40

        # error weights = fifth-order weights minus fourth-order weights
        r1 = 1 / 360
        # r2 = 0
        r3 = -128 / 4275
        r4 = -2197 / 75240
        r5 = 1 / 50
        r6 = 2 / 55

        # fourth-order weights used to advance the state
        c1 = 25 / 216
        # c2 = 0
        c3 = 1408 / 2565
        c4 = 2197 / 4104
        c5 = -1 / 5

        def stepper(
            state_data: NumericArray, t: float, dt: float
        ) -> tuple[NumericArray, float]:
            """Basic stepper to estimate error."""
            # do the six intermediate steps
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + b21 * k1, t + a2 * dt)
            k3 = dt * rhs(state_data + b31 * k1 + b32 * k2, t + a3 * dt)
            k4 = dt * rhs(state_data + b41 * k1 + b42 * k2 + b43 * k3, t + a4 * dt)
            k5 = dt * rhs(
                state_data + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4,
                t + a5 * dt,
            )
            k6 = dt * rhs(
                state_data + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5,
                t + a6 * dt,
            )

            # estimate the maximal error
            error_local = r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6
            error = np.abs(error_local).max()

            state_new = state_data + c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5
            return state_new, error

        self._logger.info("Init adaptive Runge-Kutta-Fehlberg stepper")
        return stepper
class ExplicitSolver(AdaptiveSolverBase):
    """Various explicit PDE solvers."""

    name = "explicit"

    def __new__(
        cls,
        pde: PDEBase,
        scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
        **kwargs,
    ):
        """Deprecated factory that returns the solver matching `scheme`.

        Every call emits a warning (category `UserWarning`, the `warnings.warn`
        default) pointing to the dedicated solver classes. The returned object is an
        instance of :class:`EulerSolver` or :class:`RungeKuttaSolver`, not of this
        class.

        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            scheme (str):
                The explicit scheme to use: 'euler' selects the Euler solver, while
                'runge-kutta', 'rk', or 'rk45' select the Runge-Kutta solver.
            **kwargs:
                Forwarded unchanged to the selected solver, e.g. `backend` ('numpy',
                'numba', or 'auto'), `adaptive` (whether to adjust the time step
                during the simulation), and `tolerance` (the error tolerance of a
                single step used by adaptive time stepping).

        Raises:
            ValueError: if `scheme` is not one of the supported values
        """
        # deprecated since 2025-11-01
        warnings.warn(
            "`ExplicitSolver` is deprecated. Use `EulerSolver` or `RungeKuttaSolver`.",
            stacklevel=2,
        )

        # pick the concrete solver class first, then construct it in one place
        if scheme == "euler":
            solver_class = EulerSolver
        elif scheme in {"rk", "rk45", "runge-kutta"}:
            solver_class = RungeKuttaSolver
        else:
            raise ValueError(f"Scheme `{scheme}` is not supported")

        return solver_class(pde=pde, **kwargs)