Source code for pde.solvers.explicit

"""
Defines an explicit solver supporting various methods
   
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> 
"""

from typing import Callable, Optional, Tuple

import numba as nb
import numpy as np
from numba.extending import register_jitable

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.numba import jit
from .base import SolverBase


[docs]class ExplicitSolver(SolverBase): """class for solving partial differential equations explicitly""" name = "explicit" dt_min: float = 1e-10 """float: minimal time step that the adaptive solver will use""" dt_max: float = 1e10 """float: maximal time step that the adaptive solver will use""" def __init__( self, pde: PDEBase, scheme: str = "euler", *, backend: str = "auto", adaptive: bool = False, tolerance: float = 1e-4, ): """ Args: pde (:class:`~pde.pdes.base.PDEBase`): The instance describing the pde that needs to be solved scheme (str): Defines the explicit scheme to use. Supported values are 'euler' and 'runge-kutta' (or 'rk' for short). backend (str): Determines how the function is created. Accepted values are 'numpy` and 'numba'. Alternatively, 'auto' lets the code decide for the most optimal backend. adaptive (bool): When enabled, the time step is adjusted during the simulation using the error tolerance set with `tolerance`. tolerance (float): The error tolerance used in adaptive time stepping. This is used in adaptive time stepping to choose a time step which is small enough so the truncation error of a single step is below `tolerance`. """ super().__init__(pde) self.scheme = scheme self.backend = backend self.adaptive = adaptive self.tolerance = tolerance def _make_error_synchronizer(self) -> Callable[[float], float]: """return helper function that synchronizes errors between multiple processes""" @register_jitable def synchronize_errors(error: float) -> float: return error return synchronize_errors # type: ignore def _make_dt_adjuster(self) -> Callable[[float, float, float], float]: """return a function that can be used to adjust time steps""" dt_min = self.dt_min dt_min_err = f"Time step below {dt_min}" dt_max = self.dt_max def adjust_dt(dt: float, error_rel: float, t: float) -> float: """helper function that adjust the time step Args: dt (float): Current time step error_rel (float): Current (normalized) error estimate t (float): Current time point Returns: float: Time step of the next iteration """ # adjust the time step if error_rel < 0.00057665: # error was very small => maximal increase in dt # The constant on the right hand side of the comparison is chosen to # agree with the equation for adjusting dt below dt *= 4.0 elif np.isnan(error_rel): # state contained NaN => decrease time step strongly dt *= 0.25 else: # otherwise, adjust time step according to error dt *= max(0.9 * error_rel**-0.2, 0.1) # limit time step to permissible bracket if dt > dt_max: dt = dt_max elif dt < dt_min: if np.isnan(error_rel): raise RuntimeError("Encountered NaN during simulation") else: raise RuntimeError(dt_min_err) return dt if self.backend == "numba": adjust_dt = jit(adjust_dt) return adjust_dt def _make_fixed_euler_stepper( self, state: FieldBase, dt: float ) -> Callable[[np.ndarray, float, int], Tuple[float, float]]: """make a simple Euler stepper with fixed time step Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted dt (float): Time step of the explicit stepping. Returns: Function that can be called to advance the `state` from time `t_start` to time `t_end`. The function call signature is `(state: numpy.ndarray, t_start: float, steps: int)` """ # obtain post-step action function modify_after_step = jit(self.pde.make_modify_after_step(state)) if self.pde.is_sde: # handle stochastic version of the pde rhs_sde = self._make_sde_rhs(state, backend=self.backend) def stepper( state_data: np.ndarray, t_start: float, steps: int ) -> Tuple[float, float]: """compiled inner loop for speed""" modifications = 0.0 for i in range(steps): # calculate the right hand side t = t_start + i * dt evolution_rate, noise_realization = rhs_sde(state_data, t) state_data += dt * evolution_rate if noise_realization is not None: state_data += np.sqrt(dt) * noise_realization modifications += modify_after_step(state_data) return t + dt, modifications self.info["stochastic"] = True self._logger.info( f"Initialized explicit Euler-Maruyama stepper with dt=%g", dt ) else: # handle deterministic version of the pde rhs_pde = self._make_pde_rhs(state, backend=self.backend) def stepper( state_data: np.ndarray, t_start: float, steps: int ) -> Tuple[float, float]: """compiled inner loop for speed""" modifications = 0 for i in range(steps): # calculate the right hand side t = t_start + i * dt state_data += dt * rhs_pde(state_data, t) modifications += modify_after_step(state_data) return t + dt, modifications self.info["stochastic"] = False self._logger.info(f"Initialized explicit Euler stepper with dt=%g", dt) return stepper def _make_adaptive_euler_stepper( self, state: FieldBase ) -> Callable[ [np.ndarray, float, float, float, Optional[OnlineStatistics]], Tuple[float, float, int, float], ]: """make an adaptive Euler stepper Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted Returns: Function that can be called to advance the `state` from time `t_start` to time `t_end`. The function call signature is `(state: numpy.ndarray, t_start: float, t_end: float)` """ if self.pde.is_sde: raise RuntimeError( "Cannot use adaptive Euler stepper with stochastic equation" ) # obtain functions determining how the PDE is evolved rhs_pde = self._make_pde_rhs(state, backend=self.backend) modify_after_step = jit(self.pde.make_modify_after_step(state)) # obtain auxiliary functions sync_errors = self._make_error_synchronizer() adjust_dt = self._make_dt_adjuster() tolerance = self.tolerance dt_min = self.dt_min compiled = self.backend == "numba" def stepper( state_data: np.ndarray, t_start: float, t_end: float, dt_init: float, dt_stats: Optional[OnlineStatistics] = None, ) -> Tuple[float, float, int, float]: """compiled inner loop for speed""" modifications = 0.0 dt_opt = dt_init t = t_start calculate_rate = True # flag stating whether to calculate rate for time t steps = 0 while True: # use a smaller (but not too small) time step if close to t_end dt_step = max(min(dt_opt, t_end - t), dt_min) if calculate_rate: rate = rhs_pde(state_data, t) calculate_rate = False # else: rate is reused from last (failed) iteration # single step with dt k1 = state_data + dt_step * rate # double step with half the time step k2 = state_data + 0.5 * dt_step * rate k2 += 0.5 * dt_step * rhs_pde(k2, t + 0.5 * dt_step) # calculate maximal error if compiled: error = 0.0 for i in range(state_data.size): # max() has the weird behavior that `max(np.nan, 0)` is `np.nan` # while `max(0, np.nan) == 0`. To propagate NaNs in the # evaluation, we thus need to use the following order: error = max(abs(k1.flat[i] - k2.flat[i]), error) else: error = np.abs(k1 - k2).max() error_rel = error / tolerance # normalize error to given tolerance # synchronize the error between all processes (if necessary) error_rel = sync_errors(error_rel) # do the step if the error is sufficiently small if error_rel <= 1: steps += 1 t += dt_step state_data[...] = k2 modifications += modify_after_step(state_data) calculate_rate = True if dt_stats is not None: dt_stats.add(dt_step) if t < t_end: # adjust the time step and continue dt_opt = adjust_dt(dt_step, error_rel, t) else: break # return to the controller return t, dt_opt, steps, modifications self._logger.info(f"Initialized adaptive Euler stepper") return stepper def _make_rk45_stepper( self, state: FieldBase, dt: float ) -> Callable[[np.ndarray, float, int], Tuple[float, float]]: """make a simple stepper for the explicit Runge-Kutta method of order 5(4) Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted dt (float): Time step of the explicit stepping. Returns: Function that can be called to advance the `state` from time `t_start` to time `t_end`. The function call signature is `(state: numpy.ndarray, t_start: float, steps: int)` """ if self.pde.is_sde: raise RuntimeError( "Runge-Kutta stepper does not support stochastic equations" ) self.info["stochastic"] = False # obtain functions determining how the PDE is evolved rhs = self._make_pde_rhs(state, backend=self.backend) modify_after_step = jit(self.pde.make_modify_after_step(state)) def stepper( state_data: np.ndarray, t_start: float, steps: int ) -> Tuple[float, float]: """compiled inner loop for speed""" modifications = 0.0 for i in range(steps): # calculate the right hand side t = t_start + i * dt # calculate the intermediate values in Runge-Kutta k1 = dt * rhs(state_data, t) k2 = dt * rhs(state_data + 0.5 * k1, t + 0.5 * dt) k3 = dt * rhs(state_data + 0.5 * k2, t + 0.5 * dt) k4 = dt * rhs(state_data + k3, t + dt) state_data += (k1 + 2 * k2 + 2 * k3 + k4) / 6 modifications += modify_after_step(state_data) return t + dt, modifications self._logger.info(f"Initialized explicit Runge-Kutta-45 stepper with dt=%g", dt) return stepper def _make_rkf_stepper( self, state: FieldBase ) -> Callable[ [np.ndarray, float, float, float, Optional[OnlineStatistics]], Tuple[float, float, int, float], ]: """make an adaptive stepper using the explicit Runge-Kutta-Fehlberg method Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted Returns: Function that can be called to advance the `state` from time `t_start` to time `t_end`. The function call signature is `(state: numpy.ndarray, t_start: float, t_end: float)` """ if self.pde.is_sde: raise RuntimeError( "Cannot use Runge-Kutta-Fehlberg stepper with stochastic equation" ) # obtain functions determining how the PDE is evolved rhs = self._make_pde_rhs(state, backend=self.backend) modify_after_step = jit(self.pde.make_modify_after_step(state)) self.info["stochastic"] = False # obtain auxiliary functions sync_errors = self._make_error_synchronizer() adjust_dt = self._make_dt_adjuster() tolerance = self.tolerance dt_min = self.dt_min compiled = self.backend == "numba" # use Runge-Kutta-Fehlberg method # define coefficients for RK4(5), formula 2 Table III in Fehlberg a2 = 1 / 4 a3 = 3 / 8 a4 = 12 / 13 a5 = 1.0 a6 = 1 / 2 b21 = 1 / 4 b31 = 3 / 32 b32 = 9 / 32 b41 = 1932 / 2197 b42 = -7200 / 2197 b43 = 7296 / 2197 b51 = 439 / 216 b52 = -8.0 b53 = 3680 / 513 b54 = -845 / 4104 b61 = -8 / 27 b62 = 2.0 b63 = -3544 / 2565 b64 = 1859 / 4104 b65 = -11 / 40 r1 = 1 / 360 # r2 = 0 r3 = -128 / 4275 r4 = -2197 / 75240 r5 = 1 / 50 r6 = 2 / 55 c1 = 25 / 216 # c2 = 0 c3 = 1408 / 2565 c4 = 2197 / 4104 c5 = -1 / 5 def stepper( state_data: np.ndarray, t_start: float, t_end: float, dt_init: float, dt_stats: Optional[OnlineStatistics] = None, ) -> Tuple[float, float, int, float]: """compiled inner loop for speed""" modifications = 0.0 dt_opt = dt_init t = t_start steps = 0 while True: # use a smaller (but not too small) time step if close to t_end dt_step = max(min(dt_opt, t_end - t), dt_min) # do the six intermediate steps k1 = dt_step * rhs(state_data, t) k2 = dt_step * rhs(state_data + b21 * k1, t + a2 * dt_step) k3 = dt_step * rhs(state_data + b31 * k1 + b32 * k2, t + a3 * dt_step) k4 = dt_step * rhs( state_data + b41 * k1 + b42 * k2 + b43 * k3, t + a4 * dt_step ) k5 = dt_step * rhs( state_data + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4, t + a5 * dt_step, ) k6 = dt_step * rhs( state_data + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5, t + a6 * dt_step, ) # estimate the maximal error if compiled: error = 0.0 for i in range(state_data.size): error_local = abs( r1 * k1.flat[i] + r3 * k3.flat[i] + r4 * k4.flat[i] + r5 * k5.flat[i] + r6 * k6.flat[i] ) # max() has the weird behavior that `max(np.nan, 0)` is `np.nan` # while `max(0, np.nan) == 0`. To propagate NaNs in the evaluation, # we thus need to use the following order: error = max(error_local, error) # type: ignore else: error_local = r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6 error = np.abs(error_local).sum() error_rel = error / tolerance # normalize error to given tolerance # synchronize the error between all processes (if necessary) error_rel = sync_errors(error_rel) # do the step if the error is sufficiently small if error_rel <= 1: steps += 1 t += dt_step state_data += c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5 modifications += modify_after_step(state_data) if dt_stats is not None: dt_stats.add(dt_step) if t < t_end: # adjust the time step and continue dt_opt = adjust_dt(dt_step, error_rel, t) else: break # return to the controller return t, dt_opt, steps, modifications self._logger.info(f"Initialized adaptive Runge-Kutta-Fehlberg stepper") return stepper def _make_fixed_stepper( self, state: FieldBase, dt: float ) -> Callable[[np.ndarray, float, int], Tuple[float, float]]: """return a stepper function using an explicit scheme with fixed time steps Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted dt (float): Time step of the explicit stepping. """ if self.scheme == "euler": fixed_stepper = self._make_fixed_euler_stepper(state, dt) elif self.scheme in {"runge-kutta", "rk", "rk45"}: fixed_stepper = self._make_rk45_stepper(state, dt) else: raise ValueError(f"Explicit scheme `{self.scheme}` is not supported") if self.backend == "numba": # compile inner stepper sig_fixed = (nb.typeof(state.data), nb.double, nb.int_) fixed_stepper = jit(sig_fixed)(fixed_stepper) self.info["dt_adaptive"] = False return fixed_stepper def _make_adaptive_stepper( self, state: FieldBase, dt: float ) -> Callable[ [np.ndarray, float, float, float, OnlineStatistics], Tuple[float, float, int, float], ]: """return a stepper function using an explicit scheme with fixed time steps Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted dt (float): Initial time step of the adaptive explicit stepping """ if self.pde.is_sde: raise NotImplementedError( "Adaptive stochastic stepping is not implemented. Use a fixed time " "step instead." ) self.info["stochastic"] = False if self.scheme == "euler": adaptive_stepper = self._make_adaptive_euler_stepper(state) elif self.scheme in {"runge-kutta", "rk", "rk45"}: adaptive_stepper = self._make_rkf_stepper(state) else: raise ValueError( f"Explicit adaptive scheme `{self.scheme}` is not supported" ) if self.backend == "numba": # compile inner stepper sig_adaptive = ( nb.typeof(state.data), nb.double, nb.double, nb.double, nb.typeof(self.info["dt_statistics"]), ) adaptive_stepper = jit(sig_adaptive)(adaptive_stepper) self.info["dt_adaptive"] = True return adaptive_stepper
[docs] def make_stepper( self, state: FieldBase, dt: Optional[float] = None ) -> Callable[[FieldBase, float, float], float]: """return a stepper function using an explicit scheme Args: state (:class:`~pde.fields.base.FieldBase`): An example for the state from which the grid and other information can be extracted dt (float): Time step of the explicit stepping. If `None`, this solver specifies 1e-3 as a default value. Returns: Function that can be called to advance the `state` from time `t_start` to time `t_end`. The function call signature is `(state: numpy.ndarray, t_start: float, t_end: float)` """ # support `None` as a default value, so the controller can signal that # the solver should use a default time step if dt is None: dt = 1e-3 if not self.adaptive: self._logger.warning( "Explicit stepper with a fixed time step did not receive any " f"initial value for `dt`. Using dt={dt}, but specifying a value or " "enabling adaptive stepping is advisable." ) dt_float = float(dt) # explicit casting to help type checking self.info["dt"] = dt_float self.info["steps"] = 0 self.info["scheme"] = self.scheme self.info["state_modifications"] = 0.0 if self.adaptive: # create stepper with adaptive steps self.info["dt_statistics"] = OnlineStatistics() adaptive_stepper = self._make_adaptive_stepper(state, dt_float) def wrapped_stepper( state: FieldBase, t_start: float, t_end: float ) -> float: """advance `state` from `t_start` to `t_end` using adaptive steps""" nonlocal dt_float # `dt_float` stores value for the next call t_last, dt_float, steps, modifications = adaptive_stepper( state.data, t_start, t_end, dt_float, self.info["dt_statistics"] ) self.info["steps"] += steps self.info["state_modifications"] += modifications return t_last else: # create stepper with fixed steps fixed_stepper = self._make_fixed_stepper(state, dt_float) def wrapped_stepper( state: FieldBase, t_start: float, t_end: float ) -> float: """advance `state` from `t_start` to `t_end` using fixed steps""" # calculate number of steps (which is at least 1) steps = max(1, int(np.ceil((t_end - t_start) / dt_float))) t_last, modifications = fixed_stepper(state.data, t_start, steps) self.info["steps"] += steps self.info["state_modifications"] += modifications return t_last return wrapped_stepper