"""Defines a PDE class whose right hand side is given as a string.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import numbers
import re
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Callable, Literal
import numba as nb
import numpy as np
from numba.extending import register_jitable
from numba.typed import Dict as NumbaDict
from sympy import Symbol
from sympy.core.function import UndefinedFunction
from ..fields import FieldCollection, VectorField
from ..fields.base import FieldBase
from ..fields.datafield_base import DataFieldBase
from ..grids.boundaries import set_default_bc
from ..grids.boundaries.axes import BoundariesData
from ..grids.boundaries.local import BCDataError
from ..pdes.base import PDEBase, TState
from ..tools.docstrings import fill_in_docstring
from ..tools.numba import jit
from ..tools.typing import ArrayLike, NumberOrArray, StepperHook
if TYPE_CHECKING:
import sympy
# Define short notations that can appear in mathematical equations and need to be
# expanded. Since these replacements are applied in order, it's advisable to list
# more complex expressions first
_EXPRESSION_REPLACEMENT: dict[str, str] = {
r"\|\s*∇\s*(\w+)\s*\|(²|\*\*2)": r"gradient_squared(\1)", # |∇c|² or |∇c|**2
r"∇(²|\*\*2)\s*(\w+)": r"laplace(\2)", # ∇²c or ∇**2 c
r"∇(²|\*\*2)\s*\(": r"laplace(", # ∇²(c) or ∇**2(c)
r"²": r"**2",
r"³": r"**3",
}
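# Example: applied in order, these rules turn "∇²c - |∇c|²" into
# "laplace(c) - gradient_squared(c)"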
# Define how common operators map to Fourier space
_OPERATOR_FOURIER_MAPPING = {
"laplace": "-wave_vector**2 * argument",
"gradient": "I * wave_vector * argument",
"divergence": "I * wave_vector * argument",
# "gradient_squared": "wave_vector**2 * argument**2", # or "0"? <- CHECK
}
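# Example: with wave vector symbol q, `laplace(c)` maps to `-q**2 * c`, which is how
# `_jacobian_spectral` converts linear operators to Fourier space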
class PDE(PDEBase):
"""PDE defined by mathematical expressions.
Attributes:
variables (tuple):
            The names of the variables (i.e., fields) in the order they are expected to
appear in the `state`.
"""
default_bc = "auto_periodic_neumann"
"""Default boundary condition used when no specific conditions are chosen."""
@fill_in_docstring
def __init__(
self,
rhs: dict[str, str],
*,
bc: BoundariesData | None = None,
bc_ops: dict[str, BoundariesData] | None = None,
post_step_hook: Callable[[np.ndarray, float], None] | None = None,
user_funcs: dict[str, Callable] | None = None,
consts: dict[str, NumberOrArray] | None = None,
noise: ArrayLike = 0,
rng: np.random.Generator | None = None,
):
r"""
Warning:
{WARNING_EXEC}
Args:
rhs (dict):
The expressions defining the evolution rate. The dictionary keys define
the name of the fields whose evolution is considered, while the values
specify their evolution rate as a string that can be parsed by
                :mod:`sympy`. These expressions may contain variables (i.e., the fields
themselves, spatial coordinates of the grid, and `t` for the time),
standard local mathematical operators defined by sympy, and the
operators defined in the :mod:`pde` package. Note that operators need to
be specified with their full name, i.e., `laplace` for a scalar
Laplacian and `vector_laplace` for a Laplacian operating on a vector
field. Moreover, the dot product between two vector fields can be
denoted by using `dot(field1, field2)` in the expression, an outer
product is calculated using `outer(field1, field2)`, and
`integral(field)` denotes an integral over a field.
More information can be found in the
:ref:`expression documentation <documentation-expressions>`.
bc:
General boundary conditions for all operators that do not have a
specialized condition given in `bc_ops`.
{ARG_BOUNDARIES}
bc_ops (dict):
Special boundary conditions for specific operators. The keys in this
dictionary specify where the boundary condition will be applied.
The keys follow the format "VARIABLE:OPERATOR", where VARIABLE specifies
the expression in `rhs` where the boundary condition is applied to the
operator specified by OPERATOR. For both identifiers, the wildcard
symbol "\*" denotes that all fields and operators are affected,
respectively. For instance, the identifier "c:\*" allows specifying a
condition for all operators of the field named `c`.
post_step_hook (callable):
A function with signature `(state_data, t)` that will be called after
                every time step. The function can modify the state data (a
                :class:`~numpy.ndarray`) in place, and it can abort the simulation
                immediately by raising `StopIteration`. Since the callback defined here
                will be called often, it is best to compile it with :mod:`numba` for
                speed.
user_funcs (dict, optional):
A dictionary with user defined functions that can be used in the
expressions in `rhs`.
consts (dict, optional):
A dictionary with user defined constants that can be used in the
expression. These can be either scalar numbers or fields defined on the
same grid as the actual simulation.
noise (float or :class:`~numpy.ndarray`):
                Variance of additive Gaussian white noise. The default value of zero
                implies that deterministic partial differential equations are solved.
Different noise magnitudes can be supplied for each field in coupled
PDEs by either specifying a sequence of numbers or a dictionary with
values for each field.
rng (:class:`~numpy.random.Generator`):
Random number generator (default: :func:`~numpy.random.default_rng()`)
used for stochastic simulations. Note that this random number generator
                is only used for numpy functions, while compiled numba code uses the
random number generator of numba. Moreover, in simulations using
multiprocessing, setting the same generator in all processes might yield
unintended correlations in the simulation results.
Note:
The order in which the fields are given in `rhs` defines the order in which
they need to appear in the `state` variable when the evolution rate is
            calculated. Note that `dict` keeps insertion order since Python version
            3.7, so a normal dictionary can be used to define the equations.
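
        Example:
            A minimal sketch (the field name `u`, the constant `D`, and the chosen
            boundary condition are illustrative)::

                eq = PDE(
                    {"u": "D * laplace(u) - u"},
                    consts={"D": 0.5},
                    bc_ops={"u:laplace": {"value": 0}},
                )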
"""
from sympy.core.function import AppliedUndef
from ..tools.expressions import ScalarExpression
# parse noise strength
if isinstance(noise, dict):
noise = [noise.get(var, 0) for var in rhs]
if hasattr(noise, "__iter__") and len(noise) != len(rhs): # type: ignore
raise ValueError("Number of noise strengths does not match field count")
super().__init__(noise=noise, rng=rng)
# validate input
if not isinstance(rhs, dict):
rhs = dict(rhs)
if "t" in rhs:
raise ValueError("Cannot name field `t` since it denotes time")
if consts is None:
consts = {}
# turn the expression strings into sympy expressions
self._rhs_expr, self._operators = {}, {}
explicit_time_dependence = False
complex_valued = False
for var, rhs_item in rhs.items():
# replace shorthand operators
if isinstance(rhs_item, str):
rhs_item_old = rhs_item
for search, repl in _EXPRESSION_REPLACEMENT.items():
rhs_item = re.sub(search, repl, rhs_item)
if rhs_item != rhs_item_old:
self._logger.info("Transformed expression to `%s`", rhs_item)
# create placeholder dictionary of constants that will be specified later
consts_d: dict[str, NumberOrArray] = {name: 0 for name in consts}
rhs_expr = ScalarExpression(
rhs_item,
user_funcs=user_funcs,
consts=consts_d,
explicit_symbols=rhs.keys(), # type: ignore
)
if rhs_expr.depends_on("t"):
explicit_time_dependence = True
if rhs_expr.complex:
complex_valued = True
# determine undefined functions in the expression
self._operators[var] = {
func.__class__.__name__
for func in rhs_expr._sympy_expr.atoms(AppliedUndef)
if func.__class__.__name__ not in rhs_expr.user_funcs
}
self._rhs_expr[var] = rhs_expr
# set public instance attributes
self.rhs = rhs
self.variables = tuple(rhs.keys())
self.consts = consts
self.explicit_time_dependence = explicit_time_dependence
self.complex_valued = complex_valued
# setup boundary conditions
bc = set_default_bc(bc, self.default_bc)
if bc_ops is None:
bcs = {"*:*": bc}
elif isinstance(bc_ops, dict):
bcs = dict(bc_ops)
if "*:*" in bcs and bc != "auto_periodic_neumann":
self._logger.warning("Found default BCs in `bcs` and `bc_ops`")
bcs["*:*"] = bc # append default boundary conditions
else:
            raise TypeError(f"`bc_ops` must be a dictionary, but got {type(bc_ops)}")
self.bcs: dict[str, Any] = {}
for key_str, value in bcs.items():
# split on . and :
parts = re.split(r"\.|:", key_str)
if len(parts) == 1:
                if len(self.variables) == 1:
key = f"{self.variables[0]}:{key_str}"
else:
raise ValueError(
f'Boundary condition "{key_str}" is ambiguous. Use format '
'"VARIABLE:OPERATOR" instead.'
)
elif len(parts) == 2:
key = ":".join(parts)
else:
raise ValueError(f'Cannot parse boundary condition "{key_str}"')
if key in self.bcs:
self._logger.warning("Two boundary conditions for key %s", key)
self.bcs[key] = value
# save information for easy inspection
self.diagnostics["pde"] = {
"variables": list(self.variables),
"constants": sorted(self.consts),
"explicit_time_dependence": explicit_time_dependence,
"complex_valued_rhs": complex_valued,
"operators": sorted(set().union(*self._operators.values())),
}
self._cache: dict[str, dict[str, Any]] = {}
self.post_step_hook = post_step_hook
@property
def expressions(self) -> dict[str, str]:
"""Show the expressions of the PDE."""
return {k: v.expression for k, v in self._rhs_expr.items()}
def _compile_rhs_single(
self,
var: str,
ops: dict[str, Callable],
state: FieldBase,
backend: Literal["numpy", "numba"] = "numpy",
):
"""Compile a function determining the right hand side for one variable.
Args:
var (str):
The variable that is considered
ops (dict):
A dictionary of operators that can be used by this function. Note that
this dictionary might be modified in place
state (:class:`~pde.fields.FieldBase`):
The field describing the state of the PDE
backend (str):
The backend for which the data is prepared
Returns:
callable: The function calculating the RHS
"""
# modify a copy of the expression and the general operator array
expr = self._rhs_expr[var].copy()
# obtain the (differential) operators for this variable
for func in self._operators[var]:
if func in ops:
continue
# determine boundary conditions for this operator and variable
for bc_key, bc in self.bcs.items():
bc_var, bc_func = bc_key.split(":")
var_match = bc_var == var or bc_var == "*"
func_match = bc_func == func or bc_func == "*"
if var_match and func_match: # found a matching boundary condition
self.diagnostics["pde"]["bcs_used"].add(bc_key) # register it
break # continue with this BC
else:
raise RuntimeError(
"Could not find suitable boundary condition for function "
f"`{func}` applied in equation for `{var}`"
)
# Tell the user what BC we chose for a given operator
msg = "Using boundary condition `%s` for operator `%s` in PDE for `%s`"
self._logger.info(msg, bc, func, var)
# create the function evaluating the operator
try:
ops[func] = state.grid.make_operator(func, bc=bc)
except BCDataError:
# wrong data was supplied for the boundary condition
raise
except Exception as err:
err.args += (
f"Problems in boundary condition `{bc}` for operator `{func}` in "
f"PDE for `{var}`",
)
raise err
# add `bc_args` as an argument to the call of the operators to be able
# to pass additional information, like time
expr._sympy_expr = expr._sympy_expr.replace(
# only modify the relevant operator
lambda expr: isinstance(expr.func, UndefinedFunction)
and expr.name == func # noqa: B023
# and do not modify it when the bc_args have already been set
and not (
isinstance(expr.args[-1], Symbol)
and expr.args[-1].name == "bc_args"
),
# otherwise, add None and bc_args as arguments
lambda expr: expr.func(*expr.args, Symbol("none"), Symbol("bc_args")),
)
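            # e.g., `laplace(c)` in the expression now reads `laplace(c, none, bc_args)`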
# obtain the function to calculate the right hand side
signature = self.variables + ("t", "none", "bc_args")
# check whether this function depends on additional input
if any(expr.depends_on(c) for c in state.grid.axes):
# expression has a spatial dependence, too
# extend the signature
signature += tuple(state.grid.axes)
# inject the spatial coordinates into the expression for the rhs
extra_args = tuple(
state.grid.cell_coords[..., i] for i in range(state.grid.num_axes)
)
else:
# expression only depends on the actual variables
extra_args = ()
# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
if extra_vars:
extra_vars_str = ", ".join(sorted(extra_vars))
raise RuntimeError(f"Undefined variable in expression: {extra_vars_str}")
expr.vars = signature
self._logger.info("RHS for `%s` has signature %s", var, signature)
# prepare the actual function being called in the end
if backend == "numpy":
func_inner = expr._get_function(single_arg=False, user_funcs=ops)
elif backend == "numba":
func_pure = expr._get_function(
single_arg=False, user_funcs=ops, prepare_compilation=True
)
func_inner = jit(func_pure)
else:
raise ValueError(f"Unsupported backend {backend}")
def rhs_func(*args) -> np.ndarray:
"""Wrapper that inserts the extra arguments and initialized bc_args."""
bc_args = NumbaDict() # args for differential operators
bc_args["t"] = args[-1] # pass time to differential operators
return func_inner(*args, None, bc_args, *extra_args) # type: ignore
return rhs_func
def _prepare_cache(
self, state: TState, backend: Literal["numpy", "numba"] = "numpy"
) -> dict[str, Any]:
"""Prepare the expression by setting internal variables in the cache.
Note that the expensive calculations in this method are only carried out if the
state attributes change.
Args:
state (:class:`~pde.fields.FieldBase`):
The field describing the state of the PDE
backend (str):
The backend for which the data is prepared
Returns:
dict: A dictionary with information that can be reused
"""
# check the cache
cache = self._cache.get(backend, {})
if state.attributes == cache.get("state_attributes", None):
return cache # this cache was already prepared
cache = self._cache[backend] = {} # clear cache, if there was any
# check whether PDE has variables with same names as grid axes
name_overlap = set(self.rhs) & set(state.grid.axes)
if name_overlap:
raise ValueError(f"Coordinate {name_overlap} cannot be used as field name")
# check whether the state is compatible with the PDE
num_fields: int = len(self.variables)
self.diagnostics["pde"]["num_fields"] = num_fields
if isinstance(state, FieldCollection):
if num_fields != len(state):
raise ValueError(
f"Expected {num_fields} fields in state, but got {len(state)} ones"
)
elif isinstance(state, DataFieldBase):
if num_fields != 1:
raise ValueError(
f"Expected {num_fields} fields in state, but got only one"
)
else:
raise ValueError(f"Unknown state class {state.__class__.__name__}")
# check compatibility of constants and update the rhs accordingly
for name, value in self.consts.items():
# check whether the constant has a supported value
if np.isscalar(value):
pass # this simple case is fine
elif isinstance(value, DataFieldBase):
# constant is a field, which might need to be split in MPI simulation
if state.grid._mesh is not None:
value.grid.assert_grid_compatible(state.grid._mesh.basegrid)
value = state.grid._mesh.split_field_data_mpi(value.data)
else:
value.grid.assert_grid_compatible(state.grid)
value = value.data # just keep the actual discretized data
else:
raise TypeError(f"Constant has unsupported type {value.__class__}")
for rhs in self._rhs_expr.values():
rhs.consts[name] = value # type: ignore
# obtain functions used in the expression
ops_general: dict[str, Callable] = {}
# create special operators if necessary
operators = self.diagnostics["pde"]["operators"]
if "dot" in operators:
# add dot product between two vector fields. This can for instance
# appear when two gradients of scalar fields need to be multiplied
ops_general["dot"] = VectorField(state.grid).make_dot_operator(backend)
if "inner" in operators:
# inner is a synonym for dot product operator
ops_general["inner"] = VectorField(state.grid).make_dot_operator(backend)
if "outer" in operators:
# generate an operator that calculates an outer product
vec_field = VectorField(state.grid)
ops_general["outer"] = vec_field.make_outer_prod_operator(backend)
if "integral" in operators:
# add an operator that integrates a field
ops_general["integral"] = state.grid.make_integrator()
# Create the right hand sides for all variables. It is important to do this in a
# separate function, so the closures work reliably
self.diagnostics["pde"]["bcs_used"] = set() # keep track of the used BCs
cache["rhs_funcs"] = [
self._compile_rhs_single(var, ops_general.copy(), state, backend)
for var in self.variables
]
# check whether there are boundary conditions that have not been used
bcs_left = set(self.bcs.keys()) - self.diagnostics["pde"]["bcs_used"] - {"*:*"}
if bcs_left:
self._logger.warning("Unused BCs: %s", sorted(bcs_left))
# add extra information for field collection
if isinstance(state, FieldCollection):
            # isscalar can be False even if start == stop (e.g., for vector fields)
isscalar: tuple[bool, ...] = tuple(field.rank == 0 for field in state)
starts: tuple[int, ...] = tuple(slc.start for slc in state._slices)
stops: tuple[int, ...] = tuple(slc.stop for slc in state._slices)
def get_data_tuple(state_data: np.ndarray) -> tuple[np.ndarray, ...]:
"""Helper for turning state_data into a tuple of field data."""
return tuple(
(
state_data[starts[i]]
if isscalar[i]
else state_data[starts[i] : stops[i]]
)
for i in range(num_fields)
)
cache["get_data_tuple"] = get_data_tuple
        # store the attributes in the cache, which allows skipping the expensive
        # calculations above on subsequent calls. Note that this has to be the
# last expression of the method, so the cache is only valid when the
# prepare function worked successfully
cache["state_attributes"] = state.attributes
return cache
def evolution_rate(self, state: TState, t: float = 0.0) -> TState:
"""Evaluate the right hand side of the PDE.
Args:
state (:class:`~pde.fields.FieldBase`):
The field describing the state of the PDE
t (float):
The current time point
Returns:
:class:`~pde.fields.FieldBase`:
Field describing the evolution rate of the PDE
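
        Example:
            A minimal sketch, assuming a scalar field on a simple grid (the grid size
            and field name are illustrative)::

                from pde import PDE, ScalarField, UnitGrid

                state = ScalarField.random_uniform(UnitGrid([32]))
                rate = PDE({"c": "laplace(c)"}).evolution_rate(state, t=0)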
"""
cache = self._prepare_cache(state, backend="numpy")
# create an empty copy of the current field
result = state.copy()
# fill it with data
if isinstance(state, DataFieldBase):
# state is a single field
result.data[:] = cache["rhs_funcs"][0](state.data, t)
elif isinstance(state, FieldCollection):
# state is a collection of fields
for i in range(len(state)):
data_tpl = cache["get_data_tuple"](state.data)
result[i].data[:] = cache["rhs_funcs"][i](*data_tpl, t)
else:
raise TypeError(f"Unsupported field {state.__class__.__name__}")
return result
def make_post_step_hook(self, state: FieldBase) -> tuple[StepperHook, Any]:
"""Returns a function that is called after each step.
Args:
state (:class:`~pde.fields.FieldBase`):
An example for the state from which the grid and other information can
be extracted
Returns:
tuple: The first entry is the function that implements the hook. The second
entry gives the initial data that is used as auxiliary data in the hook.
This can be `None` if no data is used.
Raises:
NotImplementedError: When :attr:`post_step_hook` is `None`.
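
        Example:
            A sketch of a hook clipping negative values (the clipping rule is
            illustrative; any in-place modification of `state_data` works)::

                def hook(state_data, t):
                    state_data[state_data < 0] = 0  # keep the field non-negative

                eq = PDE({"c": "laplace(c)"}, post_step_hook=hook)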
"""
if self.post_step_hook is None:
raise NotImplementedError("`post_step_hook` not set")
else:
post_step_hook = register_jitable(self.post_step_hook)
@register_jitable
def post_step_hook_impl(state_data, t, post_step_data):
post_step_hook(state_data, t)
            return post_step_hook_impl, 0  # hook function and initial value
def _make_pde_rhs_numba_coll(
self, state: FieldCollection, cache: dict[str, Any]
) -> Callable[[np.ndarray, float], np.ndarray]:
"""Create the compiled rhs if `state` is a field collection.
Args:
state (:class:`~pde.fields.FieldCollection`):
An example for the state defining the grid and data types
cache (dict):
Cached information that will be used in the function. The cache is
populated by :meth:`PDE._prepare_cache`.
Returns:
A function with signature `(state_data, t)`, which can be called
with an instance of :class:`~numpy.ndarray` of the state data and
            the time to obtain an instance of :class:`~numpy.ndarray` giving
the evolution rate.
"""
num_fields = len(state)
data_shape = state.data.shape
rhs_list = tuple(jit(cache["rhs_funcs"][i]) for i in range(num_fields))
starts = tuple(slc.start for slc in state._slices)
stops = tuple(slc.stop for slc in state._slices)
get_data_tuple = cache["get_data_tuple"]
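        # `chain` builds the full rhs from nested closures, so that each level
        # binds its own index `i` at definition time; a plain loop over jitted
        # functions would not capture the loop variable reliably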
def chain(
i: int = 0,
inner: Callable[[np.ndarray, float, np.ndarray], None] | None = None,
) -> Callable[[np.ndarray, float], np.ndarray]:
"""Recursive helper function for applying all rhs."""
# run through all functions
rhs = rhs_list[i]
if inner is None:
# the innermost function does not need to call a child
@jit
def wrap(data_tpl: np.ndarray, t: float, out: np.ndarray) -> None:
out[starts[i] : stops[i]] = rhs(*data_tpl, t)
else:
# all other functions need to call one deeper in the chain
@jit
def wrap(data_tpl: np.ndarray, t: float, out: np.ndarray) -> None:
inner(data_tpl, t, out) # type: ignore
out[starts[i] : stops[i]] = rhs(*data_tpl, t)
if i < num_fields - 1:
# there are more items in the chain
return chain(i + 1, inner=wrap)
else:
# this is the outermost function
@jit
def evolution_rate(state_data: np.ndarray, t: float = 0) -> np.ndarray:
out = np.empty(data_shape)
with nb.objmode():
data_tpl = get_data_tuple(state_data)
wrap(data_tpl, t, out)
return out
return evolution_rate # type: ignore
# compile the recursive chain
return chain()
def _make_pde_rhs_numba( # type: ignore
self, state: TState, **kwargs
) -> Callable[[np.ndarray, float], np.ndarray]:
"""Create a compiled function evaluating the right hand side of the PDE.
Args:
state (:class:`~pde.fields.FieldBase`):
An example for the state defining the grid and data types
Returns:
A function with signature `(state_data, t)`, which can be called with an
instance of :class:`~numpy.ndarray` of the state data and the time to
            obtain an instance of :class:`~numpy.ndarray` giving the evolution rate.
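
        Example:
            A sketch of internal use (users would typically call the public
            :meth:`~pde.pdes.base.PDEBase.make_pde_rhs` with `backend="numba"`;
            `eq` and `state` are illustrative)::

                rhs = eq._make_pde_rhs_numba(state)
                rate_data = rhs(state.data, 0.0)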
"""
cache = self._prepare_cache(state, backend="numba")
if isinstance(state, DataFieldBase):
# state is a single field
return jit(cache["rhs_funcs"][0]) # type: ignore
elif isinstance(state, FieldCollection):
# state is a collection of fields
return self._make_pde_rhs_numba_coll(state, cache)
else:
raise TypeError(f"Unsupported field {state.__class__.__name__}")
def _jacobian_spectral(
self,
state_hom: numbers.Number | list | dict[str, float] | None = None,
*,
t: float = 0,
wave_vector: str | sympy.Symbol = "q",
check_steady_state: bool = True,
) -> sympy.Matrix:
"""Calculate the Jacobian in spectral representation.
Note:
This method currently only supports scalar fields, so that inner and outer
products are not permissible. Moreover, `user_funcs` are typically not
supported and `integral` does not work.
Args:
state_hom (number or list or dict):
Field values of a homogeneous state around which the Jacobian is
determined. If only a single value is given, this value is used for all
fields. If omitted, general expressions containing the fields are
returned.
t (float):
Time point necessary for explicit time dependences
wave_vector (str or :class:`~sympy.Symbol`):
Symbol denoting the wave vector.
check_steady_state (bool):
Checks whether a supplied `state_hom` is a stationary state and raises
                a `RuntimeError` otherwise.
Returns:
:class:`~sympy.Matrix`: The Jacobian matrix (evaluated at the homogeneous
state `state_hom` if provided).
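
        Example:
            A sketch for plain diffusion, where the spectral Jacobian should reduce
            to `-q**2` (the field name `c` is illustrative)::

                eq = PDE({"c": "laplace(c)"})
                eq._jacobian_spectral(state_hom=1.0)  # Matrix([[-q**2]])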
"""
import sympy
# basic checks
if wave_vector == "t":
raise ValueError("`wave_vector` must not be `t`")
if wave_vector in self.variables:
raise ValueError(f"`wave_vector` must be different from {self.variables}")
if state_hom is None:
state_dict: Mapping[str, float | complex] | None = None
else:
# prepare homogeneous state
if isinstance(state_hom, dict):
state_dict = state_hom
else:
dim = len(self.variables)
if isinstance(state_hom, numbers.Number):
state_dict = {v: state_hom for v in self.variables} # type: ignore
elif len(state_hom) != dim:
raise ValueError(f"Expect {dim} values in `state_hom`")
else:
state_dict = {v: state_hom[i] for i, v in enumerate(self.variables)}
for v, state in state_dict.items():
if not isinstance(state, numbers.Number):
raise TypeError(f"Value for field `{v}` is not a number")
# prepare fourier transformed operators
q_sym = sympy.symbols(wave_vector)
q_sym_def = sympy.symbols("wave_vector")
arg = sympy.symbols("argument")
fourier_repl = {}
for op, opF in _OPERATOR_FOURIER_MAPPING.items():
opF_expr = sympy.parse_expr(opF).subs(q_sym_def, q_sym)
op_sym = sympy.symbols(op, cls=sympy.Function)
fourier_repl[op_sym] = sympy.Lambda(arg, opF_expr)
# collect the entries of the Jacobian matrix
jacobian = []
for v1 in self.variables:
# convert expressions to Fourier space (by replacing derivatives)
expr = self._rhs_expr[v1]._sympy_expr.subs("t", t)
exprF = expr.subs(fourier_repl)
# check that state_hom marks a stationary state
if check_steady_state and state_dict is not None:
exprF0 = exprF.subs(wave_vector, 0)
try:
exprF0_val = float(exprF0.subs(state_dict))
except Exception as e:
if len(e.args) >= 1:
e.args = (e.args[0] + f" (Expression: {exprF0})",) + e.args[1:]
raise
if not np.isclose(exprF0_val, 0):
raise RuntimeError("State is not a stationary state")
# calculate Jacobian
jac_line = []
for v2 in self.variables:
el = exprF.diff(v2)
if state_dict is not None:
el = el.subs(state_dict)
jac_line.append(sympy.simplify(el))
jacobian.append(jac_line)
return sympy.Matrix(jacobian)
def _dispersion_relation(
self,
state_hom: list | dict[str, float],
qs: np.ndarray | None = None,
*,
t: float = 0,
) -> tuple[np.ndarray, np.ndarray]:
"""Evaluate the dispersion relation.
Args:
state_hom (list or dict):
Field values for the homogeneous state around which the Jacobian is
determined.
qs (:class:`~numpy.ndarray`):
Wave vectors at which the dispersion relation is evaluated.
t (float):
Time point necessary for explicit time dependences
Returns:
tuple of :class:`~numpy.ndarray`: Wave vectors and associated eigenvalues of
the Jacobian
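
        Example:
            A sketch for plain diffusion, whose single eigenvalue branch should
            follow `-q**2` (all values are illustrative)::

                eq = PDE({"c": "laplace(c)"})
                qs, evs = eq._dispersion_relation([1.0], qs=np.linspace(0, 2, 16))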
"""
import sympy
if qs is None:
qs = np.linspace(0, 1)
jac = self._jacobian_spectral(state_hom, t=t, wave_vector="wave_vector")
evs_list = []
for q in qs:
jacN = sympy.matrix2numpy(jac.subs("wave_vector", q), dtype=complex)
evs = np.linalg.eigvals(jacN)
evs_list.append(evs)
return qs, np.array(evs_list)