"""Defines an implicit Euler solver.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from ..tools.misc import get_array_namespace
from .base import ConvergenceError, SolverBase
from .euler import _DUMMY_FUNCTION, _DUMMY_FUNCTION_2ARGS
if TYPE_CHECKING:
from collections.abc import Callable
from ..backends.base import BackendBase
from ..pdes.base import PDEBase
from ..tools.typing import NumericArray, TField
class ImplicitSolver(SolverBase):
    """Implicit (backward) Euler PDE solver.

    A step of size ``dt`` solves the implicit update equation
    ``u(t + dt) = u(t) + dt * rhs(u(t + dt), t + dt)`` by fixed-point iteration,
    starting from an explicit Euler estimate. The iteration stops once the mean
    squared change between successive iterates drops below ``maxerror**2``. A
    :class:`ConvergenceError` is raised if this does not happen within ``maxiter``
    iterations.
    """

    name = "implicit"

    def __init__(
        self,
        pde: PDEBase,
        *,
        maxiter: int = 100,
        maxerror: float = 1e-4,
        backend: str | BackendBase = "auto",
    ):
        """
        Args:
            pde (:class:`~pde.pdes.base.PDEBase`):
                The partial differential equation that should be solved
            maxiter (int):
                Maximum number of fixed-point iterations per implicit step
            maxerror (float):
                Error tolerance of the fixed-point iteration. A step is accepted
                once the root-mean-square change between two successive iterates
                is smaller than this value.
            backend (str or :class:`~pde.backends.base.BackendBase`):
                The backend used for numerical operations
        """
        super().__init__(pde, backend=backend)
        self.maxiter = maxiter
        self.maxerror = maxerror

    def _make_single_step_fixed_dt_deterministic(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Return a function doing a deterministic step with an implicit Euler scheme.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the implicit step

        Returns:
            A function ``step(state_data, t)`` that overwrites ``state_data`` with the
            state at time ``t + dt`` and returns it.

        Raises:
            RuntimeError: If the PDE is stochastic.
        """
        if self.pde.is_sde:
            msg = "Deterministic implicit Euler does not support stochastic equations"
            raise RuntimeError(msg)

        # NOTE: the step function returned below does not update this counter
        self.info["function_evaluations"] = 0
        self.info["stochastic"] = False

        rhs = self.backend.make_pde_rhs(self.pde, state)
        maxiter = int(self.maxiter)
        # compare the mean squared change with the squared tolerance (avoids a sqrt)
        maxerror2 = self.maxerror**2

        def implicit_step(state_data: NumericArray, t: float) -> NumericArray:
            """Compiled inner loop for speed."""
            # support any backend following Python Array API
            nx = get_array_namespace(state_data)

            # save state at current time point t for stepping
            state_t = state_data.copy()

            # explicit Euler estimate of the state at the next time point
            state_data[:] = state_t + dt * rhs(state_data, t)
            state_prev = nx.empty_like(state_data)

            # fixed point iteration for improving state after dt
            for _ in range(maxiter):
                state_prev[:] = state_data  # keep previous state to judge convergence
                # another iteration to improve estimate
                state_data[:] = state_t + dt * rhs(state_data, t + dt)

                # mean squared change between iterates (conj handles complex data)
                err = 0.0
                for j in range(state_data.size):
                    diff: NumericArray = state_data.flat[j] - state_prev.flat[j]
                    err += (nx.conj(diff) * diff).real
                err /= state_data.size

                if err < maxerror2:
                    # fix point iteration converged
                    break
            else:
                # loop exhausted without `break`
                msg = "Implicit Euler step did not converge."
                raise ConvergenceError(msg)

            return state_data

        self._logger.info("Initialize implicit Euler single-step update with dt=%g", dt)
        return implicit_step

    def _make_single_step_fixed_dt_stochastic(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Return a function doing a step for a SDE with an implicit Euler scheme.

        The noise is added once to the state at the current time point. The
        deterministic part is then treated implicitly via fixed-point iteration
        (semi-implicit Euler-Maruyama).

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the implicit step

        Returns:
            A function ``step(state_data, t)`` that overwrites ``state_data`` with the
            state at time ``t + dt`` and returns it.

        Raises:
            NotImplementedError: If the PDE uses the `make_noise_realization`
                interface.
        """
        if self.pde.use_noise_realization:
            msg = "Implicit stepper cannot handle `make_noise_realization` interface"
            raise NotImplementedError(msg)

        # NOTE: the step function returned below does not update this counter
        self.info["function_evaluations"] = 0
        self.info["stochastic"] = True

        # deterministic evolution rate of the PDE
        rhs = self.backend.make_pde_rhs(self.pde, state)

        # functions providing the noise variance and unit Gaussian noise; dummies
        # are used when the PDE does not provide a noise variance
        if use_noise_variance := self.pde.use_noise_variance:
            noise_var = self.pde.make_noise_variance(  # type: ignore
                state, backend=self.backend, ret_diff=False
            )
            gaussian_noise = self.backend.make_gaussian_noise(state, rng=self.pde.rng)
        else:
            noise_var = _DUMMY_FUNCTION_2ARGS
            gaussian_noise = self.backend.compile_function(_DUMMY_FUNCTION)
        noise_var = self.backend.compile_function(noise_var)

        # noise variance scales with inverse cell volumes
        inv_cell = 1 / state.grid.cell_volumes

        maxiter = int(self.maxiter)
        # compare the mean squared change with the squared tolerance (avoids a sqrt)
        maxerror2 = self.maxerror**2

        def implicit_step(state_data: NumericArray, t: float) -> NumericArray:
            """Compiled inner loop for speed."""
            nx = get_array_namespace(state_data)

            # save state at current time point t for stepping
            state_t = state_data.copy()
            state_prev = nx.empty_like(state_data)

            # evaluate deterministic part and variance without modifying field, yet
            evolution_rate = rhs(state_data, t)

            # add the noise to the reference state at the current time point and
            # adapt the state at the next time point iteratively below
            if use_noise_variance:
                noise_var_field = noise_var(state_data, t)
                state_t += nx.sqrt(dt * noise_var_field * inv_cell) * gaussian_noise()
            state_data[:] = state_t + dt * evolution_rate  # estimated new state

            # fixed point iteration for improving state after dt
            for _ in range(maxiter):
                state_prev[:] = state_data  # keep previous state to judge convergence
                # another iteration to improve estimate
                state_data[:] = state_t + dt * rhs(state_data, t + dt)

                # mean squared change between iterates (conj handles complex data)
                err = 0.0
                for j in range(state_data.size):
                    diff: NumericArray = state_data.flat[j] - state_prev.flat[j]
                    err += (nx.conj(diff) * diff).real
                err /= state_data.size

                if err < maxerror2:
                    # fix point iteration converged
                    break
            else:
                # loop exhausted without `break`
                msg = "Semi-implicit Euler-Maruyama step did not converge."
                raise ConvergenceError(msg)

            return state_data

        self._logger.info(
            "Initialize semi-implicit Euler-Maruyama single-step update with dt=%g",
            dt,
        )
        return implicit_step

    def _make_single_step_fixed_dt(
        self, state: TField, dt: float
    ) -> Callable[[NumericArray, float], NumericArray]:
        """Return a function doing a single step with an implicit Euler scheme.

        Dispatches to the stochastic or deterministic variant depending on whether
        the PDE is stochastic.

        Args:
            state (:class:`~pde.fields.base.FieldBase`):
                An example for the state from which the grid and other information can
                be extracted
            dt (float):
                Time step of the implicit step

        Returns:
            A function ``step(state_data, t)`` advancing the state by ``dt``.
        """
        if self.pde.is_sde:
            return self._make_single_step_fixed_dt_stochastic(state, dt)
        return self._make_single_step_fixed_dt_deterministic(state, dt)