Source code for pde.solvers.runge_kutta

"""Defines an explicit solver supporting various methods.

.. autosummary::
   :nosignatures:

   RungeKuttaSolver

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from ..tools.misc import get_array_namespace
from .base import AdaptiveSolverBase

if TYPE_CHECKING:
    from collections.abc import Callable

    from ..tools.typing import NumericArray, TField


class RungeKuttaSolver(AdaptiveSolverBase):
    """Explicit Runge-Kutta PDE solver.

    The stepper with a fixed time step implements the classical fourth-order
    Runge-Kutta scheme. The stepper with error estimate implements the embedded
    Runge-Kutta-Fehlberg 4(5) pair: the state is advanced with the fourth-order
    formula and the difference to the fifth-order formula serves as the error
    estimate.
    """

    name = "runge-kutta"

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Make function doing a single classical (4th-order) Runge-Kutta step.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function that advances the state by one time step `dt`. The function
            call signature is `(state_data: numpy.ndarray, t: float)`, where `t` is
            the time at the beginning of the step. The function updates
            `state_data` with an in-place addition (`+=`) and returns it.

        Raises:
            RuntimeError: If the PDE is a stochastic differential equation.
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self.backend.make_pde_rhs(self.pde, state)

        def stepper(state_data: NumericArray, t: float) -> NumericArray:
            """Advance `state_data` from time `t` by one classical RK4 step."""
            # slopes at the start, twice at the midpoint, and at the end of the
            # step; each one is already scaled by the time step
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + 0.5 * k1, t + 0.5 * dt)
            k3 = dt * rhs(state_data + 0.5 * k2, t + 0.5 * dt)
            k4 = dt * rhs(state_data + k3, t + dt)

            # weighted average of the slopes (weights 1/6, 1/3, 1/3, 1/6); the
            # augmented assignment deliberately keeps the original update form
            state_data += (k1 + 2 * k2 + 2 * k3 + k4) / 6
            return state_data

        # NOTE(review): the label "Runge-Kutta-45" is kept verbatim so log output
        # does not change, although this stepper is the classical 4th-order scheme
        self._logger.info("Initialize explicit Runge-Kutta-45 stepper with dt=%g", dt)
        return stepper

    def _make_single_step_error_estimate(
        self, state: TField
    ) -> Callable[[NumericArray, float, float], tuple[NumericArray, float]]:
        """Make a stepper with error estimate using the Runge-Kutta-Fehlberg method.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted

        Returns:
            Function that advances the state by a single time step and estimates
            the error of this step. The function call signature is
            `(state_data: numpy.ndarray, t: float, dt: float)`, where `t` is the
            time at the beginning of the step and `dt` is the step size. It returns
            the tuple `(state_new, error)`, where `state_new` is the fourth-order
            solution after the step and `error` is the maximal absolute difference
            between the fifth-order and the fourth-order solution.

        Raises:
            RuntimeError: If the PDE is a stochastic differential equation.
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self.backend.make_pde_rhs(self.pde, state)

        # use Runge-Kutta-Fehlberg method
        # define coefficients for RK4(5), formula 2 Table III in Fehlberg

        # nodes: fractions of the time step at which the stages are evaluated
        a2 = 1 / 4
        a3 = 3 / 8
        a4 = 12 / 13
        a5 = 1.0
        a6 = 1 / 2

        # stage coefficients: weights of earlier stages entering each stage
        b21 = 1 / 4
        b31 = 3 / 32
        b32 = 9 / 32
        b41 = 1932 / 2197
        b42 = -7200 / 2197
        b43 = 7296 / 2197
        b51 = 439 / 216
        b52 = -8.0
        b53 = 3680 / 513
        b54 = -845 / 4104
        b61 = -8 / 27
        b62 = 2.0
        b63 = -3544 / 2565
        b64 = 1859 / 4104
        b65 = -11 / 40

        # error weights: fifth-order weights minus fourth-order weights
        r1 = 1 / 360
        # r2 = 0
        r3 = -128 / 4275
        r4 = -2197 / 75240
        r5 = 1 / 50
        r6 = 2 / 55

        # weights of the fourth-order solution
        c1 = 25 / 216
        # c2 = 0
        c3 = 1408 / 2565
        c4 = 2197 / 4104
        c5 = -1 / 5

        def stepper(
            state_data: NumericArray, t: float, dt: float
        ) -> tuple[NumericArray, float]:
            """Basic stepper to estimate error."""
            # support any backend following Python Array API
            nx = get_array_namespace(state_data)

            # do the six intermediate steps
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + b21 * k1, t + a2 * dt)
            k3 = dt * rhs(state_data + b31 * k1 + b32 * k2, t + a3 * dt)
            k4 = dt * rhs(state_data + b41 * k1 + b42 * k2 + b43 * k3, t + a4 * dt)
            k5 = dt * rhs(
                state_data + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4,
                t + a5 * dt,
            )
            k6 = dt * rhs(
                state_data + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5,
                t + a6 * dt,
            )

            # estimate the maximal error; the Array API standard specifies `max`
            # as a namespace function and does not guarantee a `.max()` method
            error_local = r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6
            error = nx.max(nx.abs(error_local))

            # advance with the fourth-order solution; this builds a new array
            # instead of updating `state_data` in place
            state_new = state_data + c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5
            return state_new, error

        self._logger.info("Initialize adaptive Runge-Kutta-Fehlberg stepper")
        return stepper