# Source code for pde.backends.registry

"""Defines the registry for managing backends.

.. autosummary::
   :nosignatures:

   BackendRegistry
   load_default_config
   get_backend
   registered_backends

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import contextlib
import importlib
import importlib.util
import logging
from typing import TYPE_CHECKING

from .. import config as global_config
from .base import _RESERVED_BACKEND_NAMES, BackendBase

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
    from pathlib import Path

    from ..tools.config import ConfigLike, Parameter


_logger = logging.getLogger(name=__name__)
""":class:`logging.Logger`: Module-level logger used for warnings and info messages."""


class BackendRegistry:
    """Class handling all backends and their configurations.

    Backends can exist in three different states in registry:

    * Registered meta-information on how to load a backend package
    * Loaded backend module, so the class is available
    * Fully instantiated :class:`~pde.backends.base.BackendBase` classes
    """

    _packages: dict[str, str]
    """dict: backends whose packages have been registered"""
    _classes: dict[str, type[BackendBase]]
    """dict: backends whose classes have been defined"""
    _backends: dict[str, BackendBase]
    """dict: backends that have been instantiated"""

    def __init__(self) -> None:
        self._packages = {}
        self._classes = {}
        self._backends = {}

    @staticmethod
    def _warn_if_reserved(name: str) -> None:
        """Log a warning if `name` is one of the reserved backend names."""
        if name in _RESERVED_BACKEND_NAMES:
            _logger.warning("Reserved backend name `%s` should not be used.", name)

    def _all_names(self) -> set[str]:
        """Return the names of all backends known in any of the three states."""
        return self._packages.keys() | self._classes.keys() | self._backends.keys()

    def register_package(
        self,
        name: str,
        package_path: str,
        *,
        config: ConfigLike | Sequence[Parameter] | None = None,
    ) -> None:
        """Register a backend python package (without loading it yet).

        The package is only imported when the backend class is first requested.
        A configuration node for the backend is inserted into the global
        configuration immediately.

        Args:
            name (str):
                Name of the backend
            package_path (str):
                Import path for the package
            config (list):
                Configuration options for the package. If `None`, a new node is
                created in the global configuration via `create_node`.

        Raises:
            RuntimeError: If a package was already registered under `name`.
        """
        self._warn_if_reserved(name)
        if name in self._packages:
            msg = f"Cannot redefine backend `{name}`"
            raise RuntimeError(msg)
        self._packages[name] = package_path

        # temporarily allow inserting new nodes/leaves into the global config
        with global_config.changed_mode(node="insert", leaf="insert"):
            if config is None:
                global_config["backend"].create_node(name)
            else:
                global_config["backend"][name] = config

    def register_class(self, name: str, cls: type[BackendBase]) -> None:
        """Register a backend class.

        Registering a class under an existing name replaces the previous class.

        Args:
            name (str):
                Name of the backend
            cls (subclass of :class:`~pde.backends.base.BackendBase`):
                The class for creating a backend
        """
        self._warn_if_reserved(name)
        if name in self._classes:
            _logger.info("Redefining backend `%s`", name)
        self._classes[name] = cls

    def register_backend(self, backend: BackendBase) -> None:
        """Register a loaded backend object.

        The backend is stored under its `name` attribute, replacing any backend
        that was previously stored under the same name.

        Args:
            backend (:class:`~pde.backends.base.BackendBase`):
                Implementation of the backend
        """
        self._warn_if_reserved(backend.name)
        if backend.name in self._backends:
            _logger.info("Reloading backend `%s`", backend.name)
        self._backends[backend.name] = backend

    def _get_class(self, name: str) -> type[BackendBase]:
        """Get the class associated with a particular backend.

        If the class is not known yet, the registered package is imported, which
        is expected to call :meth:`register_class` as a side effect.

        Args:
            name (str):
                The name of the backend class to load

        Returns:
            The backend class registered under `name`

        Raises:
            KeyError: If neither a class nor a package is registered for `name`.
            RuntimeError: If importing the package did not register the class.
        """
        if name not in self._classes:
            # determine the backend information to load it
            try:
                package_path = self._packages[name]
            except KeyError as err:
                # list every name that could be resolved to a class
                known = ", ".join(sorted(self._packages.keys() | self._classes.keys()))
                msg = f"Backend `{name}` not in [{known}]"
                raise KeyError(msg) from err

            # load the backend from a python package, which should register its class
            importlib.import_module(package_path)
            if name not in self._classes:
                msg = f"Backend `{name}` was loaded, but did not register its class."
                raise RuntimeError(msg)

        return self._classes[name]

    def _get_backend(
        self, name: str, *, link_global_config: bool, **kwargs
    ) -> BackendBase:
        """Create backend object, potentially loading the respective package.

        Args:
            name (str):
                Name of the backend to be loaded. Everything after the first
                colon is passed as `args` to the class method `from_args`.
            link_global_config (bool):
                If true, this backend may directly use the global configuration
                as its own, linking changes to its config to the global one.
                This will only be enabled if `name` directly specifies the
                backend without additional configurations.
            **kwargs:
                Additional options of the backend

        Returns:
            :class:`~pde.backends.base.BackendBase`: An instance of the backend
            with the particular configuration
        """
        # split `name` into the class name and optional arguments after a colon
        cls_name, _, args = name.partition(":")
        cls = self._get_class(cls_name)

        # determine the configuration of the backend
        if link_global_config and name == cls_name:
            # no colon in `name` -> directly use the global configuration
            backend_config = global_config["backend"][cls_name]
        else:
            # detach config of specific backend from general backend
            backend_config = global_config["backend"][cls_name].copy()

        # create the backend
        if args:
            return cls.from_args(backend_config, args, name=name, **kwargs)
        return cls(backend_config, name=name, **kwargs)

    def get_backend(
        self, name: str, *, config: ConfigLike | None = None, **kwargs
    ) -> BackendBase:
        """Return backend object, potentially loading the respective package.

        The returned backend is cached if `config` is not specified (or empty).
        Consequently, the same object will be returned for repeated calls to
        `get_backend`, which allows sharing configuration parameters. Moreover,
        if `name` only specifies a backend, i.e., does not contain a colon `:`,
        the configuration of this backend is linked with the global
        configuration :obj:`pde.config`, such that changes to the config are
        reflected globally.

        Args:
            name (str):
                Name of the backend to be loaded. The special name `default`
                is replaced by the value of the global `default_backend` option.
            config (dict):
                Additional configuration options for this specific backend. The
                full configuration will be taken from the global configuration
                and merged with the given options here.
            **kwargs:
                Additional options of the backend. Note that these are only used
                when a backend is created; they are ignored when a cached
                backend is returned.

        Returns:
            :class:`~pde.backends.base.BackendBase`: An instance of the backend
            with the particular configuration
        """
        # handle special names
        if name == "default":
            name = global_config["default_backend"]

        if config:
            # return backend with custom configuration, which will not be cached
            backend = self._get_backend(name, link_global_config=False, **kwargs)
            backend.config.update_recursive(config)
            return backend

        # check whether the precise backend has been instantiated already
        if name not in self._backends:
            # create backend from the class definition
            backend = self._get_backend(name, link_global_config=True, **kwargs)
            self.register_backend(backend)
        return self._backends[name]

    def __contains__(self, name: str) -> bool:
        """Check whether a backend with the given name is known (or `default`)."""
        return name == "default" or name in self._all_names()

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names of the defined backends in sorted order."""
        return iter(sorted(self._all_names()))

    def values(self) -> Iterator[BackendBase]:
        """Iterate over all backends that can be imported.

        Backends whose loading raises :class:`ImportError` are skipped.
        """
        for name in self:
            with contextlib.suppress(ImportError):
                yield self.get_backend(name)
# the single, module-wide registry; no other instance of the class should exist
backend_registry = BackendRegistry()
def load_default_config(module_path: str | Path) -> list[Parameter] | None:
    """Load a default configuration from a module.

    The python file at `module_path` is executed as a fresh module (it is not
    added to `sys.modules`) and its variable `DEFAULT_CONFIG` is returned.

    Args:
        module_path (str or :class:`~pathlib.Path`):
            Path to the file of the module to be loaded

    Returns:
        list: The value of `DEFAULT_CONFIG` defined in the module, or `None` if
        the module could not be loaded or does not define this variable.
    """
    # derive a module name from the path; the module is loaded from a file, so
    # it has no name from the regular import system
    module_name = str(module_path).translate(str.maketrans("./\\", "___"))
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        # without a spec and a loader, the module cannot be executed
        _logger.warning("Could not load module `%s`", module_path)
        return None

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    try:
        return module.DEFAULT_CONFIG  # type: ignore
    except AttributeError:
        _logger.warning("Configuration module had no variable `DEFAULT_CONFIG`")
        return None
def get_backend(
    backend: str | BackendBase, config: ConfigLike | None = None
) -> BackendBase:
    """Return backend specified by a string or an instance.

    The returned backend is cached if `config` is not specified. Consequently,
    the same object will be returned for repeated calls to `get_backend`, which
    allows sharing configuration parameters. Moreover, if `backend` is a string
    that only specifies a backend, i.e., does not contain a colon `:`, the
    configuration of this backend is linked with the global configuration
    :obj:`pde.config`, such that changes to the config are reflected globally.

    Args:
        backend (str or :class:`~pde.backends.base.BackendBase`):
            Backend specified by name given as a string. If the string contains
            a colon, the first part determines the backend, whereas the second
            part can be used to convey additional information. For example,
            :code:`torch:cuda` may load a torch backend and use a cuda device.
            As a special case, we also allow full backend objects, which are
            simply returned. This is a simple way to allow providing full
            backend objects in places where we otherwise would expect a backend
            name.
        config (dict):
            Additional configuration options for this specific backend. The
            full configuration will be taken from the global configuration and
            merged with the given options here. This option is only permitted if
            `backend` is a string since there otherwise might be unintended side
            effects of modifying an existing backend.

    Returns:
        :class:`~pde.backends.base.BackendBase`: An initialized backend

    Raises:
        RuntimeError: If `config` is given together with a backend instance.
        TypeError: If `backend` is neither a string nor a backend instance.
    """
    if isinstance(backend, BackendBase):
        # backend already initialized -> its configuration must not be changed
        if config:
            msg = "Configuration can only be set for new backend."
            raise RuntimeError(msg)
        return backend

    if isinstance(backend, str):
        # create (or fetch the cached) backend given by name
        return backend_registry.get_backend(backend, config=config)

    msg = (
        "Backend must be specified by a string or a BackendBase instance, "
        f"not {type(backend).__name__}"
    )
    raise TypeError(msg)
def registered_backends() -> list[str]:
    """Return the names of all backends known to the registry.

    Returns:
        list of str: Backend names in the order produced by iterating over the
        registry, i.e., sorted alphabetically.
    """
    names: list[str] = [*backend_registry]
    return names