# Source code for pde.solvers.milstein

"""Defines an explicit Milstein solver for stochastic differential equations.

.. autosummary::
   :nosignatures:

   MilsteinSolver

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ..tools.misc import get_array_namespace
from .euler import _DUMMY_FUNCTION_2ARGS, EulerSolver

if TYPE_CHECKING:
    from collections.abc import Callable

    from pde.tools.typing import NumericArray, TField

    from ..backends.base import BackendBase
    from ..pdes.base import PDEBase


class MilsteinSolver(EulerSolver):
    """Milstein method for stochastic differential equations.

    The solver requires the PDE to describe its noise through the variance
    interface (``pde.use_noise_variance``), because the Milstein correction needs
    both the noise variance and its derivative, which are requested from
    :meth:`make_noise_variance` with ``ret_diff=True``.
    """

    name = "milstein"

    def __init__(
        self,
        pde: PDEBase,
        *,
        backend: str | BackendBase = "auto",
        adaptive: bool = False,
        tolerance: float = 1e-4,
    ):
        """Initialize the Milstein solver.

        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            backend (str or :class:`~pde.backends.base.BackendBase`):
                The backend used for numerical operations
            adaptive (bool):
                Whether to use adaptive time stepping
            tolerance (float):
                Error tolerance for adaptive time stepping

        Raises:
            RuntimeError: If the PDE does not provide its noise via the variance
                interface, i.e., if `pde.use_noise_variance` is false.
        """
        super().__init__(pde, backend=backend, adaptive=adaptive, tolerance=tolerance)
        if not pde.use_noise_variance:
            msg = "Milstein solver requires `use_noise_variance` enabled."
            raise RuntimeError(msg)

    def _make_single_step_fixed_dt_stochastic(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Make a Euler-Milstein single-step update with fixed time step.

        Info:
            This solver should only be used for problems with multiplicative noise.
            While the solver works for additive noise, the extra corrections might slow
            down calculations.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the explicit stepping.

        Returns:
            Function that updates the state by one time step. The function call
            signature is `(state_data: numpy.ndarray, t: float)`. The state data is
            modified in place and also returned.

        Raises:
            RuntimeError: If the PDE does not use the noise-variance interface.
        """
        # create deterministic part
        rhs_pde = self.backend.make_pde_rhs(self.pde, state)

        # first noise interface: the PDE supplies the noise variance (and its
        # derivative, which the Milstein correction needs). This is normally
        # guaranteed by `__init__`; it is checked explicitly (instead of with
        # `assert`) so the guard survives running Python with `-O`.
        if not self.pde.use_noise_variance:
            msg = "Milstein solver requires `use_noise_variance` enabled."
            raise RuntimeError(msg)
        noise_drift_factor = self.pde._noise_drift_factor
        fn = self.pde.make_noise_variance(  # type: ignore
            state, backend=self.backend, ret_diff=True
        )
        noise_var = self.backend.compile_function(fn)
        gaussian_noise = self.backend.make_gaussian_noise(state, rng=self.pde.rng)

        # second noise interface: the PDE supplies a noise realization directly
        if use_noise_realization := self.pde.use_noise_realization:
            rhs_noise = self.pde.make_noise_realization(state, backend=self.backend)  # type: ignore
        else:
            rhs_noise = _DUMMY_FUNCTION_2ARGS
        rhs_noise = self.backend.compile_function(rhs_noise)

        # noise increment scales with square root of time step
        dt_sqrt = np.sqrt(dt)
        # noise variance scales with inverse cell volumes
        inv_cell = 1 / state.grid.cell_volumes

        def single_step(state_data: NumericArray, t: float) -> NumericArray:
            """Perform a single Euler-Milstein step (modifies `state_data`)."""
            # support any backend following Python Array API
            nx = get_array_namespace(state_data)

            # evaluate deterministic part and variance before modifying the field
            evolution_rate = rhs_pde(state_data, t)
            noise_var_field, noise_var_diff_field = noise_var(state_data, t)

            # handle second noise interface
            if use_noise_realization:
                noise_realization = rhs_noise(state_data, t)
                if noise_realization is not None:
                    state_data += dt_sqrt * noise_realization

            # Wiener increment for this step
            dW = dt_sqrt * gaussian_noise()
            # With noise amplitude b = sqrt(v / cell_volume), the terms below are:
            #   deterministic Euler step: dt * rhs
            #   noise-induced drift, weighted by the PDE's `_noise_drift_factor`
            #   Euler-Maruyama noise: b * dW (multiplicative, state-dependent)
            #   Milstein correction: 1/2 b b' (dW^2 - dt) = 1/4 v' (dW^2 - dt)
            state_data += (
                dt * evolution_rate
                + 0.5 * dt * noise_drift_factor * noise_var_diff_field * inv_cell
                + nx.sqrt(noise_var_field * inv_cell) * dW
                + 0.25 * noise_var_diff_field * inv_cell * (dW**2 - dt)
            )
            return state_data

        self._logger.info(
            "Initialize explicit Euler-Milstein single-step update with dt=%g", dt
        )
        return single_step