"""Defines an explicit solver supporting various methods.

.. autosummary::
   :nosignatures:

   ExplicitSolver
   EulerSolver
   RungeKuttaSolver

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from ..tools.misc import get_array_namespace
from .base import AdaptiveSolverBase
if TYPE_CHECKING:
from collections.abc import Callable
from ..tools.typing import NumericArray, TField
class RungeKuttaSolver(AdaptiveSolverBase):
    """Explicit Runge-Kutta PDE solver.

    Fixed time steps use the classical fourth-order Runge-Kutta scheme, while
    adaptive time stepping uses the embedded Runge-Kutta-Fehlberg 4(5) pair, whose
    fifth-order companion solution provides an estimate of the local error.
    """

    name = "runge-kutta"

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Make function doing a single explicit fourth-order Runge-Kutta step.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function with call signature `(state_data: NumericArray, t: float)` that
            advances `state_data` from time `t` to time `t + dt` using the classical
            fourth-order Runge-Kutta scheme and returns the result. The update uses
            an augmented assignment, so mutable arrays (e.g., numpy arrays) are
            modified in place.

        Raises:
            RuntimeError: if the PDE is stochastic, since this scheme is
                deterministic.
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self.backend.make_pde_rhs(self.pde, state)

        def stepper(state_data: NumericArray, t: float) -> NumericArray:
            """Advance `state_data` by a single time step `dt`."""
            # the four stages of the classical Runge-Kutta scheme
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + 0.5 * k1, t + 0.5 * dt)
            k3 = dt * rhs(state_data + 0.5 * k2, t + 0.5 * dt)
            k4 = dt * rhs(state_data + k3, t + dt)
            # combine the stages with weights 1/6, 1/3, 1/3, 1/6
            state_data += (k1 + 2 * k2 + 2 * k3 + k4) / 6
            return state_data

        self._logger.info("Initialize explicit Runge-Kutta-45 stepper with dt=%g", dt)
        return stepper

    def _make_single_step_error_estimate(
        self, state: TField
    ) -> Callable[[NumericArray, float, float], tuple[NumericArray, float]]:
        """Make an adaptive stepper using the explicit Runge-Kutta-Fehlberg method.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted

        Returns:
            Function with call signature `(state_data: NumericArray, t: float,
            dt: float)` returning a tuple `(state_new, error)`. Here, `state_new` is
            the fourth-order estimate of the state at time `t + dt` (a newly created
            array; the stepper itself does not modify `state_data`) and `error` is
            the maximal absolute difference between the embedded fifth- and
            fourth-order solutions, which estimates the local error of the step.
            Depending on the array backend, `error` may be a zero-dimensional array
            rather than a Python float.

        Raises:
            RuntimeError: if the PDE is stochastic, since this scheme is
                deterministic.
        """
        if self.pde.is_sde:
            msg = "Deterministic Runge-Kutta does not support stochastic equations"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs = self.backend.make_pde_rhs(self.pde, state)

        # use Runge-Kutta-Fehlberg method
        # define coefficients for RK4(5), formula 2 Table III in Fehlberg
        # nodes: fractions of `dt` at which the stages 2-6 are evaluated
        a2 = 1 / 4
        a3 = 3 / 8
        a4 = 12 / 13
        a5 = 1.0
        a6 = 1 / 2
        # coupling coefficients: stage i uses state + sum_j b_ij * k_j
        b21 = 1 / 4
        b31 = 3 / 32
        b32 = 9 / 32
        b41 = 1932 / 2197
        b42 = -7200 / 2197
        b43 = 7296 / 2197
        b51 = 439 / 216
        b52 = -8.0
        b53 = 3680 / 513
        b54 = -845 / 4104
        b61 = -8 / 27
        b62 = 2.0
        b63 = -3544 / 2565
        b64 = 1859 / 4104
        b65 = -11 / 40
        # error weights: fifth-order minus fourth-order weights (r2 = 0)
        r1 = 1 / 360
        r3 = -128 / 4275
        r4 = -2197 / 75240
        r5 = 1 / 50
        r6 = 2 / 55
        # weights of the fourth-order solution used to advance the state (c2 = 0)
        c1 = 25 / 216
        c3 = 1408 / 2565
        c4 = 2197 / 4104
        c5 = -1 / 5

        def stepper(
            state_data: NumericArray, t: float, dt: float
        ) -> tuple[NumericArray, float]:
            """Advance `state_data` by `dt` and estimate the local error."""
            # support any backend following Python Array API
            nx = get_array_namespace(state_data)

            # do the six intermediate steps
            k1 = dt * rhs(state_data, t)
            k2 = dt * rhs(state_data + b21 * k1, t + a2 * dt)
            k3 = dt * rhs(state_data + b31 * k1 + b32 * k2, t + a3 * dt)
            k4 = dt * rhs(state_data + b41 * k1 + b42 * k2 + b43 * k3, t + a4 * dt)
            k5 = dt * rhs(
                state_data + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4,
                t + a5 * dt,
            )
            k6 = dt * rhs(
                state_data + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5,
                t + a6 * dt,
            )

            # estimate the maximal error; the Array API standard only provides `max`
            # as a namespace function (arrays need not have a `.max()` method)
            error_local = r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6
            error = nx.max(nx.abs(error_local))

            # advance the state using the fourth-order solution
            state_new = state_data + c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5
            return state_new, error

        self._logger.info("Initialize adaptive Runge-Kutta-Fehlberg stepper")
        return stepper