4.1.1.2 pde.backends.jax.backend module
Defines the jax backend class.
- class JaxBackend(config=None, *, name=None, device='config')[source]
Bases: BackendBase[Array]
Defines the jax backend.
Initialize the jax backend.
- Parameters:
- compile_function(func)[source]
General method that compiles a user function.
- Parameters:
func (callable) – The function that needs to be compiled for this backend
- Return type:
TFunc
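As a rough illustration of what compiling a user function can mean for a JAX-based backend, the sketch below wraps a plain function with jax.jit. Whether compile_function is implemented this way is an assumption, and rhs is a hypothetical user function.

```python
import jax
import jax.numpy as jnp


def rhs(u):
    """Hypothetical user function: a simple cubic reaction term."""
    return u - u**3


# Assumption: compiling for the jax backend amounts to JIT compilation
rhs_compiled = jax.jit(rhs)

u = jnp.linspace(-1.0, 1.0, 5)
result = rhs_compiled(u)  # traced and compiled on the first call
```

The compiled function is called exactly like the original one; only the first call pays the tracing and compilation cost.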
- copy_data = True
Flag indicating whether data needs to be copied between numpy’s representation on CPU and a native device.
- Type:
- get_jax_dtype(dtype)[source]
Convert numpy dtype to jax-compatible dtype.
- Parameters:
dtype (DTypeLike) – numpy dtype to convert to corresponding jax dtype
- Returns:
A proper dtype usable for jax
- Return type:
np.dtype
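JAX works with 32-bit types unless the jax_enable_x64 flag is set, which is one reason a numpy dtype may need converting. The sketch below shows one possible conversion based on jax.dtypes.canonicalize_dtype. The helper to_jax_dtype is hypothetical and not necessarily how get_jax_dtype is implemented.

```python
import numpy as np
import jax


def to_jax_dtype(dtype):
    """Hypothetical helper: map a numpy dtype to the dtype jax would use."""
    # canonicalize_dtype narrows 64-bit types to 32-bit unless the
    # jax_enable_x64 flag is set
    return np.dtype(jax.dtypes.canonicalize_dtype(dtype))


dt_double = to_jax_dtype(np.float64)  # float32 or float64, depending on the flag
dt_single = to_jax_dtype(np.float32)  # unchanged
```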
- implementation = 'jax'
The name of the Python module that is used to implement this backend. This information can be used to distinguish the general implementation of backends.
- Type:
- make_data_setter(grid, rank, bcs=None)[source]
Create a function to set the valid part of a full data array.
- Parameters:
grid (GridBase) – The grid for which the data setter is created
rank (int) – Rank of the data represented on the grid
bcs (BoundariesBase, optional) – Defines the boundary conditions for a particular grid, for which the setter should be defined.
- Returns:
Takes two numpy arrays, setting the valid data in the first one, using the second array. The arrays need to be allocated already and they need to have the correct dimensions, which are not checked. If bcs are given, a third argument is allowed, which sets arguments for the BCs.
- Return type:
callable
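The idea of a data setter can be sketched with plain numpy: the full array carries ghost cells around the valid data, and the setter writes only into the interior. The factory make_data_setter_sketch is hypothetical. It assumes one layer of ghost cells on each side of every grid axis and ignores boundary conditions.

```python
import numpy as np


def make_data_setter_sketch(num_axes):
    """Hypothetical stand-in for a data setter without boundary conditions."""
    # Assumption: one layer of ghost cells on each side of every grid axis
    interior = (Ellipsis,) + (slice(1, -1),) * num_axes

    def set_valid(data_full, data_valid):
        # writes in place; shapes are not checked, as described above
        data_full[interior] = data_valid

    return set_valid


set_valid = make_data_setter_sketch(num_axes=2)
data_full = np.zeros((6, 5))  # 4 x 3 valid cells plus ghost cells
set_valid(data_full, np.ones((4, 3)))
```

After the call, the ghost cells still hold their previous values (zero here), and only the 4 x 3 interior block has been overwritten.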
- make_expression_function(expression, *, single_arg=False, user_funcs=None)[source]
Return a function evaluating an expression.
- Parameters:
expression (ExpressionBase) – The expression that is converted to a function
single_arg (bool) – Determines whether the returned function accepts all variables in a single argument as an array or whether all variables need to be supplied separately.
user_funcs (dict) – Additional functions that can be used in the expression.
- Returns:
A function that evaluates the expression
- Return type:
function
- make_gaussian_noise(field, *, rng)[source]
Create a function generating Gaussian white noise.
This noise is already scaled to respect different cell volumes of the grid.
- Parameters:
field (FieldBase) – An example for the state from which the grid and other information can be extracted
rng (Generator) – Random number generator (default: default_rng()) used to initialize the seed.
- Return type:
Callable[[], jax.Array]
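The following sketch shows one way a noise generator of type Callable[[], jax.Array] could be built. It makes two assumptions that the text above does not confirm: the numpy generator is used only to draw a seed for a JAX random key, and scaling for the cell volumes means dividing unit-variance noise by the square root of each cell volume. The name make_noise_sketch is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def make_noise_sketch(cell_volumes, rng):
    """Hypothetical noise factory returning a Callable[[], jax.Array]."""
    # Assumption: the numpy generator only provides the seed for the jax key
    seed = int(rng.integers(0, 2**31 - 1))
    state = {"key": jax.random.PRNGKey(seed)}
    # Assumption: white noise has standard deviation 1 / sqrt(cell volume)
    scale = 1.0 / jnp.sqrt(jnp.asarray(cell_volumes))

    def noise():
        # split the key so that every call yields fresh random numbers
        state["key"], subkey = jax.random.split(state["key"])
        return scale * jax.random.normal(subkey, scale.shape)

    return noise


noise = make_noise_sketch(np.full((8, 8), 0.25), np.random.default_rng(0))
sample_a, sample_b = noise(), noise()
```

Keeping the key in a closure makes the returned function argument-free, at the price of hidden state; such a function should not itself be wrapped in jax.jit.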
- make_inner_prod_operator(field, *, conjugate=True)[source]
Return operator calculating the dot product between two fields.
This supports both products between two vectors as well as products between a vector and a tensor.
- Parameters:
field (DataFieldBase) – Field for which the inner product is defined
conjugate (bool) – Whether to use the complex conjugate for the second operand
- Returns:
Function that takes two instances of ndarray, which contain the discretized data of the two operands. An optional third argument can specify the output array to which the result is written.
- Return type:
Callable[[jax.Array, jax.Array, jax.Array | None], jax.Array]
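A dot product between two vector fields can be sketched with jnp.einsum by contracting the leading component axis and keeping the grid axes. The factory below is hypothetical. It covers only the vector-vector case and omits the optional output array, since JAX arrays are immutable.

```python
import jax.numpy as jnp


def make_inner_prod_sketch(conjugate=True):
    """Hypothetical dot product of two vector fields on the same grid."""

    def dot(a, b):
        if conjugate:
            # complex-conjugate the second operand
            b = jnp.conj(b)
        # contract the leading (component) axis, keep all grid axes
        return jnp.einsum("i...,i...->...", a, b)

    return dot


dot = make_inner_prod_sketch()
a = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 2 components, 3 cells
result = dot(a, a)
```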
- make_operator(grid, operator, *, bcs, dtype=None, **kwargs)[source]
Return a compiled function applying an operator with boundary conditions.
- Parameters:
grid (GridBase) – Grid for which the operator is needed
operator (str) – Identifier for the operator. Some examples are ‘laplace’, ‘gradient’, or ‘divergence’. The registered operators for this grid can be obtained from the operators attribute.
bcs (BoundariesBase, optional) – The boundary conditions used before the operator is applied
dtype (numpy dtype) – The data type of the field.
**kwargs – Specifies extra arguments influencing how the operator is created.
- Returns:
The function that applies the operator. This function has the signature (arr: NumericArray, out: NumericArray = None, args=None).
The returned function takes the discretized data on the grid as input and returns the data to which the operator has been applied. The function only takes the valid grid points and allocates memory for the ghost points internally to apply the boundary conditions specified as bcs. Note that the function supports an optional argument out, which, if given, should provide space for the valid output array without the ghost cells. The result of the operator is then written into this output array.
The function also accepts an optional parameter args, which is forwarded to set_ghost_cells. This allows setting boundary conditions based on external parameters, like time. When this backend is used together with JAX's just-in-time compilation (e.g. via jax.jit()), the values passed through args need to be compatible with JAX's JIT tracing rules.
- Return type:
callable
- Parameters:
grid (GridBase)
operator (str | OperatorInfo)
bcs (BoundariesBase)
dtype (DTypeLike | None)
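The interplay of ghost cells, boundary conditions, and args can be sketched for a one-dimensional Laplacian. Everything in this example is hypothetical: make_laplace_sketch is not part of the library, the boundary condition is a Dirichlet value read from args["value"], and the ghost-cell formula assumes cell-centered data with one ghost cell per side.

```python
import jax
import jax.numpy as jnp


def make_laplace_sketch(dx):
    """Hypothetical 1D Laplacian with a Dirichlet value supplied via args."""

    @jax.jit
    def laplace(arr, args):
        # allocate the full array with one ghost cell on each side
        full = jnp.pad(arr, 1)
        # Assumption: ghost values chosen so the boundary takes args["value"]
        full = full.at[0].set(2 * args["value"] - arr[0])
        full = full.at[-1].set(2 * args["value"] - arr[-1])
        # second-order central difference evaluated on the valid cells
        return (full[2:] - 2 * full[1:-1] + full[:-2]) / dx**2

    return laplace


laplace = make_laplace_sketch(dx=1.0)
flat = laplace(jnp.ones(4), {"value": 1.0})
curved = laplace(jnp.ones(4), {"value": 0.0})
```

Because args is a dictionary with a numeric leaf, JAX can trace it, and changing the value does not trigger recompilation. A non-numeric entry such as a string would not be traceable in this form.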
- make_operator_no_bc(grid, operator, *, dtype=None, **kwargs)[source]
Return a compiled function applying an operator without boundary conditions.
The returned function takes the discretized full data as input, together with an array of valid data points to which the result of applying the operator is written.
Note
The resulting function does not check whether the ghost cells of the input array have been supplied with sensible values. It is the responsibility of the user to set the values of the ghost cells beforehand. Use this function only if you absolutely know what you’re doing. In all other cases, make_operator() is probably the better choice.
- Parameters:
grid (GridBase) – Grid for which the operator is needed
operator (str) – Identifier for the operator. Some examples are ‘laplace’, ‘gradient’, or ‘divergence’. The registered operators for this grid can be obtained from the operators attribute.
dtype (numpy dtype) – The data type of the field.
**kwargs – Specifies extra arguments influencing how the operator is created.
- Returns:
The function that applies the operator. This function has the signature (arr: NumericArray, out: NumericArray), so the out array needs to be supplied explicitly.
- Return type:
callable
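To make the contrast with make_operator() concrete, the sketch below applies a one-dimensional Laplacian to full data whose ghost cells have already been filled by the caller. It uses plain numpy with an explicit out array. The function laplace_no_bc is hypothetical and assumes one ghost cell per side.

```python
import numpy as np


def laplace_no_bc(arr_full, out, dx=1.0):
    """Hypothetical 1D Laplacian acting on data that includes ghost cells."""
    # the ghost cells arr_full[0] and arr_full[-1] are used as they are
    out[:] = (arr_full[2:] - 2 * arr_full[1:-1] + arr_full[:-2]) / dx**2


# samples of x**2 at x = 0, 1, 2, 3, 4; the first and last entry are ghost cells
arr_full = np.array([0.0, 1.0, 4.0, 9.0, 16.0])
out = np.empty(3)
laplace_no_bc(arr_full, out)
```

If the ghost cells held stale values, the two outermost entries of out would silently be wrong, which is the risk the note above warns about.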
- make_outer_prod_operator(field)[source]
Return operator calculating the outer product between two fields.
This typically only supports products between two vector fields.
- Parameters:
field (DataFieldBase) – Field for which the outer product is defined
- Returns:
Function that takes two instances of ndarray, which contain the discretized data of the two operands. An optional third argument can specify the output array to which the result is written.
- Return type:
Callable[[jax.Array, jax.Array, jax.Array | None], jax.Array]
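For two vector fields, the outer product at every grid point can be sketched with jnp.einsum. The function outer_sketch is hypothetical and omits the optional output array.

```python
import jax.numpy as jnp


def outer_sketch(a, b):
    """Hypothetical outer product of two vector fields on the same grid."""
    # result[i, j, ...] = a[i, ...] * b[j, ...] at every grid point
    return jnp.einsum("i...,j...->ij...", a, b)


a = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 2 components, 3 cells
b = jnp.array([[1.0, 0.0, 1.0], [2.0, 2.0, 2.0]])
tensor = outer_sketch(a, b)  # rank-2 tensor field of shape (2, 2, 3)
```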
- make_stepper(solver, state)[source]
Create a field-based stepping function for a given solver.
- Parameters:
solver (SolverBase) – The solver instance, which determines how the stepper is constructed
state (FieldBase) – An example for the state from which the grid and other information can be extracted
- Returns:
Function that can be called to advance the state from time t_start to time t_end. The function call signature is (state: numpy.ndarray, t_start: float, t_end: float)
- Return type: