"""Handles configuration variables of the package.

.. autosummary::
   :nosignatures:

   Parameter
   Config
   GlobalConfig
   get_package_versions
   parse_version_str
   check_package_version
   packages_from_requirements
   get_ffmpeg_version
   is_hpc_environment
   environment

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations

import collections
import contextlib
import importlib
import importlib.metadata
import logging
import os
import re
import subprocess as sp
import sys
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np

from .misc import module_available

if TYPE_CHECKING:
    from collections.abc import Sequence
[docs]
class Parameter:
"""Class representing a single parameter."""
def __init__(
self,
name: str,
default_value=None,
cls=object,
description: str = "",
hidden: bool = False,
extra: dict[str, Any] | None = None,
):
"""Initialize a parameter.
Args:
name (str):
The name of the parameter
default_value:
The default value
cls:
The type of the parameter, which is used for conversion
description (str):
A string describing the impact of this parameter. This
description appears in the parameter help
hidden (bool):
Whether the parameter is hidden in the description summary
extra (dict):
Extra arguments that are stored with the parameter
"""
self.name = name
self.default_value = default_value
self.cls = cls
self.description = description
self.hidden = hidden
self.extra = {} if extra is None else extra
if cls is not object:
# check whether the default value is of the correct type
converted_value = cls(default_value)
if isinstance(converted_value, np.ndarray):
# numpy arrays are checked for each individual value
valid_default = np.allclose(
converted_value, default_value, equal_nan=True
)
else:
# other values are compared directly. Note that we also check identity
# to capture the case where the value is `math.nan`, where the direct
# comparison (nan == nan) would evaluate to False
valid_default = (
converted_value is default_value or converted_value == default_value
)
if not valid_default:
if hasattr(self, "_logger"):
logger: logging.Logger = self._logger
else:
logger = logging.getLogger(self.__class__.__module__)
logger.warning(
"Default value `%s` does not seem to be of type `%s`",
name,
cls.__name__,
)
def __repr__(self):
return (
f'{self.__class__.__name__}(name="{self.name}", default_value='
f'"{self.default_value}", cls="{self.cls.__name__}", '
f'description="{self.description}", hidden={self.hidden})'
)
__str__ = __repr__
[docs]
def convert(self, value=None):
"""Converts a `value` into the correct type for this parameter. If `value` is
not given, the default value is converted.
Note that this does not make a copy of the values, which could lead to
unexpected effects where the default value is changed by an instance.
Args:
value: The value to convert
Returns:
The converted value, which is of type `self.cls`
"""
if value is None:
value = self.default_value
if self.cls is object:
return value
try:
return self.cls(value)
except ValueError as err:
msg = (
f"Could not convert {value!r} to {self.cls.__name__} for parameter "
f"'{self.name}'"
)
raise ValueError(msg) from err
# define default parameter values
DEFAULT_CONFIG: list[Parameter] = [
    Parameter(
        "operators.conservative_stencil",
        True,
        bool,
        "Indicates whether conservative stencils should be used for differential "
        "operators on curvilinear grids. Conservative operators ensure mass "
        "conservation at slightly slower computation speed. "
        "Note that some backends might ignore this option.",
    ),
    Parameter(
        "operators.tensor_symmetry_check",
        True,
        bool,
        "Indicates whether tensor fields are checked for having a suitable form for "
        "evaluating differential operators in curvilinear coordinates where some axes "
        "are assumed to be symmetric. In such cases, some tensor components might need "
        "to vanish, so the result of the operator can be expressed. "
        "Note that some backends might ignore this option.",
    ),
    Parameter(
        "operators.cartesian.laplacian_2d_corner_weight",
        0.0,
        float,
        "Weighting factor for the corner points of the 2d cartesian Laplacian stencil. "
        "The standard value is zero, corresponding to the traditional 5-point stencil. "
        "Alternative choices are 1/2 (Oono-Puri stencil) and 1/3 (Patra-Karttunen or "
        "Mehrstellen stencil); see https://en.wikipedia.org/wiki/Nine-point_stencil. "
        "Note that some backends might ignore this option.",
    ),
    Parameter(
        "boundaries.accept_lists",
        True,
        bool,
        "Indicate whether boundary conditions can be set using the deprecated legacy "
        "format, where conditions for individual axes and sides were set using lists. "
        "If disabled, only the new format using dicts is supported.",
    ),
    Parameter(
        "default_backend",
        "numba",
        str,
        "Indicate which backend is selected by default.",
    ),
]

# plain values that may be stored in a configuration instead of a `Parameter`
ConfigValueType = str | float | int | bool | None
class Config(collections.UserDict):
    """Class handling general configurations.

    Configurations are basically dictionaries with string keys that hold
    :class:`Parameter` values, which contain a value with some extra information. For
    user-friendliness, we also allow basic values, like strings or numbers as values.
    """

    data: dict[str, ConfigValueType | Parameter]

    def __init__(
        self,
        items: Sequence[Parameter] | dict[str, Any] | Config | None = None,
        mode: str = "update",
    ):
        """
        Args:
            items (dict, optional):
                Configuration values that should be added or overwritten to initialize
                the configuration.
            mode (str):
                Defines the mode in which the configuration is used. Possible values are

                * `insert`: any new configuration key can be inserted
                * `update`: only the values of pre-existing items can be updated
                * `locked`: no values can be changed

                Note that the items specified by `items` will always be inserted,
                independent of the `mode`.
        """
        super().__init__()
        self.mode = "insert"  # temporarily allow inserting items
        if isinstance(items, dict):
            # use the parameters from the supplied dictionary
            self.update(items)
        elif isinstance(items, Config):
            # use the underlying dictionary to copy the actual Parameter instances
            self.update(items.data)
        elif items:
            # assume that this is a sequence of Parameter
            self.update({p.name: p for p in items})
        self.mode = mode

    def __getitem__(self, key: str):
        """Retrieve item `key`.

        :class:`Parameter` entries are converted to their typed value; plain values
        are returned as stored.

        Args:
            key (str): The configuration key
        """
        parameter = self.data[key]
        if isinstance(parameter, Parameter):
            return parameter.convert()
        return parameter

    def __setitem__(self, key: str, value: ConfigValueType | Parameter):
        """Update item `key` with `value`.

        Args:
            key (str): The configuration key
            value: The value to set

        Raises:
            KeyError: In `update` mode, if `key` does not exist yet
            RuntimeError: In `locked` mode
            ValueError: If `self.mode` is not a supported mode
        """
        # determine how to set the item
        if self.mode == "insert":
            self.data[key] = value

        elif self.mode == "update":
            try:
                self[key]  # test whether the key already exist (including magic keys)
            except KeyError as err:
                msg = f"{key} is not present and config is not in `insert` mode"
                raise KeyError(msg) from err
            self.data[key] = value

        elif self.mode == "locked":
            msg = "Configuration is locked"
            raise RuntimeError(msg)

        else:
            msg = f"Unsupported configuration mode `{self.mode}`"
            raise ValueError(msg)

    def __delitem__(self, key: str):
        """Removes item `key`.

        Args:
            key (str): The configuration key

        Raises:
            RuntimeError: If the configuration is not in `insert` mode
        """
        if self.mode == "insert":
            del self.data[key]
        else:
            msg = "Configuration is not in `insert` mode"
            raise RuntimeError(msg)

    def to_dict(self, *, ret_values: bool = False) -> dict[str, Any]:
        """Convert the configuration to a simple dictionary.

        Args:
            ret_values (bool):
                Whether to return only values (and not :class:`Parameter` instances)

        Returns:
            dict: A representation of the configuration in a normal :class:`dict`.
        """
        if ret_values:
            return dict(**self)
        return self.data.copy()

    def __repr__(self) -> str:
        """Represent the configuration as a string."""
        return f"{self.__class__.__name__}({self.to_dict()!r})"

    @contextlib.contextmanager
    def __call__(self, values: dict[str, Any] | None = None, **kwargs):
        """Context manager temporarily changing the configuration.

        The new values are written directly into the underlying data (bypassing the
        `mode` checks). The previous configuration is restored when the context is
        left, including when an exception propagates out of the `with` block.

        Args:
            values (dict): New configuration parameters
            **kwargs: New configuration parameters
        """
        data_initial = self.data.copy()  # save old configuration
        try:
            # set new configuration
            if values is not None:
                self.data.update(values)
            self.data.update(kwargs)
            yield  # return to caller
        finally:
            # restore old configuration, even if the caller raised an exception
            self.data = data_initial
class GlobalConfig:
    """Class handling the global package configuration.

    This class contains additional logic that allows managing multiple configurations,
    e.g., including the ones defined in :mod:`pde.backends`. The class also contains
    logic to deal with deprecated configuration options.
    """

    def __init__(
        self,
        items: Sequence[Parameter] | dict[str, Any] | None = None,
        mode: str = "update",
    ):
        """
        Args:
            items (dict, optional):
                Configuration values that should be added or overwritten to initialize
                the configuration.
            mode (str):
                Defines the mode in which the configuration is used. Possible values are

                * `insert`: any new configuration key can be inserted
                * `update`: only the values of pre-existing items can be updated
                * `locked`: no values can be changed

                Note that the items specified by `items` will always be inserted,
                independent of the `mode`.
        """
        self._config = Config(items, mode=mode)

    def _get_sub_config(self, key: str) -> tuple[Config, str]:
        """Determine the actual configuration where the data is stored.

        Some configuration items are stored in sub-configurations, e.g., those for the
        backends.

        Args:
            key (str):
                Global key to the configuration option

        Returns:
            tuple of :class:`Config` and str:
                The actual configuration where the configuration is stored and the key
                into this configuration.
        """
        if key.startswith("backend."):
            # use configurations of backend
            from ..backends import backends

            _, backend, config_key = key.split(".", 2)
            return backends[backend].config, config_key

        if key.startswith("numba."):
            # legacy location for numba related configurations; deprecated on 2025-12-22
            # stacklevel=3 skips this helper and the public accessor that called it, so
            # the warning is attributed to user code; Python's default filters only
            # show DeprecationWarnings attributed to `__main__`
            warnings.warn(
                f"Configuration `{key}` is deprecated. Use `backend.{key}` instead.",
                DeprecationWarning,
                stacklevel=3,
            )
            from ..backends import backends

            backend, config_key = key.split(".", 1)
            return backends[backend].config, config_key

        # use global configuration
        return self._config, key

    def _convert_value(self, key: str, value):
        """Helper function converting certain values.

        Boolean values for `numba.multithreading` are translated to the strings
        `"always"` and `"never"`, and a deprecation warning is emitted.

        Args:
            key (str): The configuration key
            value: The value to convert

        Returns:
            The converted value
        """
        if key.endswith("numba.multithreading") and isinstance(value, bool):
            value = "always" if value else "never"
            # Deprecated on 2025-02-12; recommend the non-deprecated `backend.` key and
            # attribute the warning to the code that called `__setitem__`
            warnings.warn(
                "Boolean options are deprecated for `numba.multithreading`. Use "
                f"config['backend.numba.multithreading'] = '{value}' instead.",
                DeprecationWarning,
                stacklevel=3,
            )
        return value

    def __contains__(self, key: str) -> bool:
        config, data_key = self._get_sub_config(key)
        return data_key in config

    def __getitem__(self, key: str):
        """Retrieve item `key`.

        Args:
            key (str): The configuration key
        """
        config, data_key = self._get_sub_config(key)
        return config[data_key]

    def __iter__(self):
        from ..backends import backends

        yield from self._config
        for backend, config in backends._configs.items():
            for subkey in config:
                yield f"backend.{backend}.{subkey}"

    def items(self, just_values: bool = True):
        """Iterate over configuration items.

        Args:
            just_values (bool):
                Whether to yield converted parameter values (`True`) or raw
                :class:`Parameter` objects (`False`)

        Yields:
            tuple: Key-value pairs of configuration items, including items from all
            backend configurations with keys prefixed by `backend.<name>.`.
        """
        from ..backends import backends

        if just_values:
            yield from self._config.items()
            for backend, config in backends._configs.items():
                for subkey, value in config.items():
                    yield f"backend.{backend}.{subkey}", value
        else:
            for key in self._config:
                yield key, self._config.data[key]
            for backend, config in backends._configs.items():
                for subkey in config:
                    yield f"backend.{backend}.{subkey}", config.data[subkey]

    def update(self, items: dict[str, Any]) -> None:
        """Set several configuration items via :meth:`__setitem__`.

        Args:
            items (dict): Mapping from configuration keys to new values
        """
        for k, v in items.items():
            self[k] = v

    def __setitem__(self, key: str, value):
        """Update item `key` with `value`.

        Args:
            key (str): The configuration key
            value: The value to set
        """
        config, data_key = self._get_sub_config(key)
        config[data_key] = self._convert_value(key, value)

    def __delitem__(self, key: str):
        """Removes item `key`.

        Args:
            key (str): The configuration key
        """
        config, data_key = self._get_sub_config(key)
        del config[data_key]

    def to_dict(
        self, *, ret_values: bool = False, incl_backends: bool = True
    ) -> dict[str, Any]:
        """Convert the configuration to a simple dictionary.

        Args:
            ret_values (bool):
                Whether to return only values (and not :class:`Parameter` instances)
            incl_backends (bool):
                Whether to include items from the backends

        Returns:
            dict: A representation of the configuration in a normal :class:`dict`.
        """
        # return the global configuration
        res = self._config.to_dict(ret_values=ret_values)

        # add configurations of the actual backends
        if incl_backends:
            from ..backends import backends

            for backend, config in backends._configs.items():
                for name, p in config.to_dict(ret_values=ret_values).items():
                    res[f"backend.{backend}.{name}"] = p

        return res

    def __repr__(self) -> str:
        """Represent the configuration as a string."""
        return f"{self.__class__.__name__}({self.to_dict(incl_backends=False)!r})"

    @contextlib.contextmanager
    def __call__(self, values: dict[str, Any] | None = None, **kwargs):
        """Context manager temporarily changing the configuration.

        The values read via :meth:`__getitem__` before the change are written back
        via :meth:`update` when the context is left, including when an exception
        propagates out of the `with` block. The `values` argument is not modified.

        Args:
            values (dict): New configuration parameters
            **kwargs: New configuration parameters
        """
        # merge into a fresh dict so the caller's `values` is left untouched
        new_values: dict[str, Any] = {**(values or {}), **kwargs}

        if not new_values:
            # nothing to do
            yield
            return

        data_initial = {key: self[key] for key in new_values}  # save old configuration
        # set new configuration
        self.update(new_values)
        try:
            yield  # return to caller
        finally:
            # restore old configuration, even if the caller raised an exception
            self.update(data_initial)
def get_package_versions(
    packages: list[str], *, na_str="not available"
) -> dict[str, str]:
    """Tries to load certain python packages and returns their version.

    Versions are read from the installed distribution metadata, so the packages are
    not imported.

    Args:
        packages (list): The names of all packages
        na_str (str): Text to return if package is not available

    Returns:
        dict: Dictionary with version for each package name, ordered by name
    """
    versions: dict[str, str] = {}
    for name in sorted(packages):
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # no installed distribution with this name
            versions[name] = na_str
    return versions
def parse_version_str(ver_str: str) -> list[int]:
    """Turn a dotted version string into a list of integers.

    Only the first three dot-separated components are inspected; components that
    cannot be parsed as an integer (e.g. ``"0rc1"``) are skipped.

    Args:
        ver_str (str): The version string to parse

    Returns:
        list[int]: List of version numbers as integers
    """
    numbers: list[int] = []
    leading_parts = ver_str.split(".")[:3]
    for part in leading_parts:
        try:
            number = int(part)
        except ValueError:
            continue  # drop components that are not plain integers
        numbers.append(number)
    return numbers
def check_package_version(package_name: str, min_version: str):
    """Checks whether a package has a sufficient version.

    A warning is emitted if the package cannot be imported, if its version cannot be
    determined, or if the installed version is older than `min_version`.

    Args:
        package_name (str): The name of the package to check
        min_version (str): The minimum required version
    """
    msg = f"`{package_name}` version {min_version} required for py-pde"
    try:
        module = importlib.import_module(package_name)
    except ImportError:
        warnings.warn(f"{msg} (but none installed)", stacklevel=2)
        return

    # not every module defines `__version__`; fall back to the distribution metadata
    version = getattr(module, "__version__", None)
    if version is None:
        try:
            version = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            warnings.warn(f"{msg} (installed version unknown)", stacklevel=2)
            return

    # check whether the installed version is recent enough
    if parse_version_str(str(version)) < parse_version_str(min_version):
        warnings.warn(f"{msg} (installed: {version})", stacklevel=2)
[docs]
def packages_from_requirements(requirements_file: Path | str) -> list[str]:
"""Read package names from a requirements file.
Args:
requirements_file (str or :class:`~pathlib.Path`):
The file from which everything is read
Returns:
list of package names
"""
result = []
try:
with Path(requirements_file).open() as fp:
for line in fp:
line_s = line.strip()
if line_s.startswith("#"):
continue
res = re.search(r"[a-zA-Z0-9_\-]+", line_s)
if res:
result.append(res.group(0))
except FileNotFoundError:
result.append(f"Could not open {requirements_file:s}")
return result
[docs]
def get_ffmpeg_version() -> str | None:
"""Read version number of ffmpeg program."""
# run ffmpeg to get its version
try:
version_bytes = sp.check_output(["ffmpeg", "-version"])
except Exception:
return None
# extract the version number from the output
version_string = version_bytes.splitlines()[0].decode("utf-8")
match = re.search(r"version\s+([\w\.]+)\s+copyright", version_string, re.IGNORECASE)
if match:
return match.group(1)
return None
def is_hpc_environment() -> bool:
    """Detect whether the code runs inside a job of an HPC batch scheduler.

    The detection looks for the job-identifier environment variables of common
    schedulers (SLURM, PBS, LSF).

    Returns:
        bool: `True` if any of these variables is defined, otherwise `False`.
    """
    for env_name in ("SLURM_JOB_ID", "PBS_JOBID", "LSB_JOBID"):
        if env_name in os.environ:
            return True
    return False
def environment() -> dict[str, Any]:
    """Obtain information about the compute environment.

    Returns:
        dict: information about the python installation and packages
    """
    import matplotlib as mpl

    from pde import config

    from .. import __version__ as package_version
    from . import mpi
    from .plotting import get_plotting_context

    resource_path = Path(__file__).resolve().parents[1] / "tools" / "resources"

    result: dict[str, Any] = {}
    result["package version"] = package_version
    result["python version"] = sys.version

    # check the compute environment
    result["environment"] = {"platform": sys.platform, "is_hpc": is_hpc_environment()}

    # add ffmpeg version if available
    ffmpeg_version = get_ffmpeg_version()
    if ffmpeg_version:
        result["ffmpeg version"] = ffmpeg_version

    # add the package configuration
    result["config"] = config.to_dict(ret_values=True)

    # add details for mandatory packages
    packages_min = packages_from_requirements(resource_path / "requirements_basic.txt")
    result["mandatory packages"] = get_package_versions(packages_min)
    result["matplotlib environment"] = {
        "backend": mpl.get_backend(),
        "plotting context": get_plotting_context().__class__.__name__,
    }

    # add information about jupyter environment
    result["jupyter environment"] = get_package_versions(
        [
            "ipykernel",
            "ipywidgets",
            "jupyter_client",
            "jupyter_core",
            "jupyter_server",
            "notebook",
        ]
    )

    # add details about optional packages
    packages = set(packages_from_requirements(resource_path / "requirements_full.txt"))
    packages |= set(packages_from_requirements(resource_path / "requirements_mpi.txt"))
    packages -= set(packages_min)
    result["optional packages"] = get_package_versions(sorted(packages))

    if module_available("numba"):
        # import the numba helper only when numba is available, so this function does
        # not depend on the numba backend module otherwise
        from ..backends.numba.utils import numba_environment

        result["numba environment"] = numba_environment()

    # add information about MPI environment
    if mpi.initialized:
        result["multiprocessing"] = {"initialized": True, "size": mpi.size}
    else:
        result["multiprocessing"] = {"initialized": False}

    return result