Source code for pde.solvers.euler

"""Defines an explicit solver supporting various methods.

.. autosummary::
   :nosignatures:

   ExplicitSolver
   EulerSolver
   RungeKuttaSolver

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np

from ..tools.math import OnlineStatistics
from .base import AdaptiveSolverBase, _make_dt_adjuster

if TYPE_CHECKING:
    from collections.abc import Callable

    from ..pdes.base import PDEBase
    from ..tools.typing import InnerStepperType, NumericArray, TField


class EulerSolver(AdaptiveSolverBase):
    """Explicit Euler solver.

    Provides a fixed-step stepper (Euler, or Euler-Maruyama when the PDE is
    stochastic) and an adaptive stepper for deterministic equations that estimates
    the error by comparing one full step against two half steps.
    """

    name = "euler"

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Make a simple Euler stepper with fixed time step.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function that advances the state by a single time step `dt`. The function
            call signature is `(state_data: numpy.ndarray, t: float)`. The state
            data is updated using augmented assignment (`+=`) and returned.
        """
        if self.pde.is_sde:
            # handle stochastic version of the pde
            rhs_pde = self.backend.make_pde_rhs(self.pde, state)
            rhs_noise = self.pde.make_noise_realization(  # type: ignore
                state, backend=self.backend
            )
            rhs_noise = self.backend.compile_function(rhs_noise)

            def stepper(state_data: NumericArray, t: float) -> NumericArray:
                """Perform a single Euler-Maruyama step."""
                evolution_rate = rhs_pde(state_data, t)
                noise_realization = rhs_noise(state_data, t)
                state_data += dt * evolution_rate
                if noise_realization is not None:
                    # the noise term scales with sqrt(dt) instead of dt
                    state_data += np.sqrt(dt) * noise_realization
                return state_data

            self._logger.info(
                "Initialize explicit Euler-Maruyama stepper with dt=%g", dt
            )

        else:
            # handle deterministic version of the pde
            rhs_pde = self.backend.make_pde_rhs(self.pde, state)

            def stepper(state_data: NumericArray, t: float) -> NumericArray:
                """Perform a single Euler step."""
                state_data += dt * rhs_pde(state_data, t)
                return state_data

            self._logger.info("Initialize explicit Euler stepper with dt=%g", dt)

        return stepper

    def _make_inner_stepper(self, state: TField) -> InnerStepperType:
        """Make an (adaptive) Euler stepper.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information
                can be extracted

        Returns:
            Function that can be called to advance the `state` from time `t_start` to
            time `t_end`. The function call signature is
            `(state: numpy.ndarray, t_start: float, t_end: float)`

        Raises:
            RuntimeError: If adaptive stepping is requested for a stochastic equation.
        """
        if not self.adaptive:
            # create stepper with fixed steps
            return super()._make_inner_stepper(state)

        # General comment: We implement the full adaptive scheme here instead of just
        # defining `_make_single_step_error_estimate` to do some optimizations. In
        # particular, we reuse the calculated right hand side in cases where the step
        # was not successful.
        if getattr(self.pde, "is_sde", False):
            msg = "Cannot use adaptive stepper with stochastic equation"
            raise RuntimeError(msg)

        # obtain functions determining how the PDE is evolved
        rhs_pde = self.backend.make_pde_rhs(self.pde, state)
        post_step_hook = self._make_post_step_hook(state)

        # obtain auxiliary functions
        sync_errors = self.backend.make_mpi_synchronizer(operator="MAX")
        adjust_dt = _make_dt_adjuster(self.dt_min, self.dt_max)
        tolerance = self.tolerance
        dt_min = self.dt_min

        # add extra information
        self.info["dt_adaptive"] = self.adaptive
        self.info["dt_statistics"] = OnlineStatistics()

        def adaptive_stepper(
            state_data: NumericArray, t_start: float, t_end: float
        ) -> float:
            """Adaptive stepper that advances the state in time.

            Args:
                state_data: Data of the state at `t_start`; overwritten with the
                    final state before returning.
                t_start: Time at which the integration starts.
                t_end: Time at which control is returned to the caller.

            Returns:
                The time that was reached by the last successful step.
            """
            state_cur = state_data
            dt_opt = self.info["dt"]  # time step from last step
            rate = rhs_pde(state_cur, t_start)  # calculate initial rate

            steps = 0
            t = t_start
            while True:
                # use a smaller (but not too small) time step if close to t_end
                dt_step = max(min(dt_opt, t_end - t), dt_min)

                # do single step with dt
                step_large = state_cur + dt_step * rate

                # do double step with half the time step
                step_small = state_cur + 0.5 * dt_step * rate

                try:
                    # calculate rate at the midpoint of the double step
                    rate_midpoint = rhs_pde(step_small, t + 0.5 * dt_step)
                except Exception:
                    # an exception likely signals that rate could not be calculated
                    error_rel = np.nan
                else:
                    # advance to end of double step
                    step_small += 0.5 * dt_step * rate_midpoint

                    # calculate maximal error
                    error = np.abs(step_large - step_small).max()
                    error_rel = error / tolerance  # normalize error to given tolerance

                # synchronize the error between all processes (necessary for MPI)
                error_rel = sync_errors(error_rel)

                if error_rel <= 1:  # error is sufficiently small
                    t_next = t + dt_step
                    try:
                        # Calculate the rate at the putative new step. The state
                        # `step_small` belongs to time `t_next`, so the rate must be
                        # evaluated there to be valid as the initial rate of the
                        # next iteration (matters for time-dependent equations).
                        rate = rhs_pde(step_small, t_next)
                    except Exception:
                        # calculating the rate failed => retry with smaller dt
                        # (`rate` still holds the value for the current state)
                        error_rel = np.nan
                    else:
                        # everything worked => do the step
                        steps += 1
                        t = t_next
                        # NOTE(review): `rate` was evaluated on `step_small` before
                        # the hook ran; if the hook alters the state, the cached
                        # rate does not reflect that change.
                        state_cur, self.info["post_step_data"] = post_step_hook(
                            step_small, t, self.info["post_step_data"]
                        )
                        dt_stats = self.info.get("dt_statistics")
                        if dt_stats is not None:
                            dt_stats.add(dt_step)

                if t < t_end:
                    # adjust the time step and continue
                    dt_opt = adjust_dt(dt_step, error_rel)
                else:
                    break  # return to the controller

            self.info["dt"] = dt_opt  # save last optimal time step
            self.info["steps"] += steps
            state_data[:] = state_cur  # write result back into the caller's array
            return t

        self._logger.info("Initialize adaptive Euler stepper")
        return adaptive_stepper
class ExplicitSolver(AdaptiveSolverBase):
    """Deprecated entry point that dispatches to a concrete explicit solver.

    Creating an instance emits a warning and returns either an
    :class:`EulerSolver` or a `RungeKuttaSolver`, depending on `scheme`.
    """

    name = "explicit"

    def __new__(
        cls,
        pde: PDEBase,
        scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
        **kwargs,
    ):
        """Create the solver selected by `scheme`.

        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            scheme (str):
                Defines the explicit scheme to use. Supported values are 'euler' and
                'runge-kutta' (with the aliases 'rk' and 'rk45').
            **kwargs:
                Additional arguments that are forwarded unchanged to the chosen
                solver class.

        Returns:
            An instance of the solver class that implements the requested scheme.

        Raises:
            ValueError: If `scheme` is not one of the supported values.
        """
        # deprecated since 2025-11-01
        warnings.warn(
            "`ExplicitSolver` is deprecated. Use `EulerSolver` or `RungeKuttaSolver`.",
            stacklevel=2,
        )

        # pick the concrete class first, then construct it in a single place
        if scheme == "euler":
            solver_cls = EulerSolver
        elif scheme in {"rk", "rk45", "runge-kutta"}:
            # imported only when this branch is taken
            from .runge_kutta import RungeKuttaSolver

            solver_cls = RungeKuttaSolver
        else:
            msg = f"Scheme `{scheme}` is not supported"
            raise ValueError(msg)

        return solver_cls(pde=pde, **kwargs)