"""Infrastructure for managing classes with parameters.

One aim is to allow easy management of inheritance of parameters.

.. autosummary::
   :nosignatures:

   Parameter
   DeprecatedParameter
   HideParameter
   Parameterized
   get_all_parameters

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import logging
import warnings
from collections.abc import Sequence
from typing import Any, Union
import numpy as np
from . import output
from .misc import hybridmethod, import_class
class Parameter:
    """Class representing a single parameter."""

    def __init__(
        self,
        name: str,
        default_value=None,
        cls=object,
        description: str = "",
        hidden: bool = False,
        extra: dict[str, Any] | None = None,
    ):
        """Initialize a parameter.

        Args:
            name (str):
                The name of the parameter
            default_value:
                The default value
            cls:
                The type of the parameter, which is used for conversion
            description (str):
                A string describing the impact of this parameter. This
                description appears in the parameter help
            hidden (bool):
                Whether the parameter is hidden in the description summary
            extra (dict):
                Extra arguments that are stored with the parameter
        """
        self.name = name
        self.default_value = default_value
        self.cls = cls
        self.description = description
        self.hidden = hidden
        self.extra = {} if extra is None else extra

        if cls is not object:
            # check whether the default value is of the correct type
            converted_value = cls(default_value)
            if isinstance(converted_value, np.ndarray):
                # numpy arrays are checked for each individual value
                valid_default = np.allclose(
                    converted_value, default_value, equal_nan=True
                )
            else:
                # other values are compared directly. Note that we also check identity
                # to capture the case where the value is `math.nan`, where the direct
                # comparison (nan == nan) would evaluate to False
                valid_default = (
                    converted_value is default_value or converted_value == default_value
                )

            if not valid_default:
                if hasattr(self, "_logger"):
                    logger: logging.Logger = self._logger
                else:
                    logger = logging.getLogger(self.__class__.__module__)
                # report the offending value itself (and which parameter it
                # belongs to) so the problem can actually be located
                logger.warning(
                    "Default value `%s` of parameter `%s` does not seem to be of "
                    "type `%s`",
                    default_value,
                    name,
                    cls.__name__,
                )

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(name="{self.name}", default_value='
            f'"{self.default_value}", cls="{self.cls.__name__}", '
            f'description="{self.description}", hidden={self.hidden})'
        )

    __str__ = __repr__

    def __getstate__(self):
        """Return a picklable state where the class is replaced by its import path.

        The path is built from the module that defines `self.cls`, so that
        `import_class` can locate non-builtin classes again in `__setstate__`.
        """
        return {
            "name": str(self.name),
            "default_value": self.convert(),
            "cls": self.cls.__module__ + "." + self.cls.__name__,
            "description": self.description,
            "hidden": self.hidden,
            "extra": self.extra,
        }

    def __setstate__(self, state):
        """Restore the object from a state created by :meth:`__getstate__`."""
        # restore the object from the class path
        state["cls"] = import_class(state["cls"])
        # restore the state
        self.__dict__.update(state)

    def convert(self, value=None):
        """Converts a `value` into the correct type for this parameter. If `value` is
        not given, the default value is converted.

        Note that this does not make a copy of the values, which could lead to
        unexpected effects where the default value is changed by an instance.

        Args:
            value: The value to convert

        Returns:
            The converted value, which is of type `self.cls`

        Raises:
            ValueError: If `self.cls` rejects the value with a `ValueError`
        """
        if value is None:
            value = self.default_value

        if self.cls is object:
            return value
        else:
            try:
                return self.cls(value)
            except ValueError as err:
                raise ValueError(
                    f"Could not convert {value!r} to {self.cls.__name__} for parameter "
                    f"'{self.name}'"
                ) from err
class DeprecatedParameter(Parameter):
    """Parameter that keeps working like a regular one but is marked deprecated.

    It adds no behavior of its own; the distinct type only allows
    :meth:`Parameterized.get_parameters` to leave deprecated entries out unless
    they are explicitly requested.
    """
class HideParameter:
    """Marker that hides a parameter inherited from the parent classes.

    Listed in ``parameters_default`` of a subclass, it refers to an inherited
    parameter by its name.
    """

    def __init__(self, name: str):
        """
        Args:
            name (str):
                Name of the inherited parameter that should be hidden
        """
        self.name: str = name
# Type of the `parameters_default` class attribute: a sequence mixing parameter
# definitions with markers that hide parameters inherited from parent classes
ParameterListType = Sequence[Union[Parameter, HideParameter]]
class Parameterized:
    """A mixin that manages the parameters of a class."""

    parameters_default: ParameterListType = []
    _subclasses: dict[str, type[Parameterized]] = {}

    def __init__(self, parameters: dict[str, Any] | None = None):
        """Initialize the parameters of the object.

        Args:
            parameters (dict):
                A dictionary of parameters to change the defaults. The allowed
                parameters can be obtained from
                :meth:`~Parameterized.get_parameters` or displayed by calling
                :meth:`~Parameterized.show_parameters`.
        """
        # deprecated on 2024-08-15
        # stacklevel=2 attributes the warning to the code creating the object
        # instead of to this module
        warnings.warn(
            "`Parameterized` has been moved to `py-modelrunner`",
            DeprecationWarning,
            stacklevel=2,
        )
        # set parameters if they have not been initialized, yet
        if not hasattr(self, "parameters"):
            self.parameters = self._parse_parameters(
                parameters, include_deprecated=True
            )

    def __init_subclass__(cls, **kwargs):
        """Register all subclasses to reconstruct them later.

        A `parameters_default` given as a dictionary mapping names to default
        values is converted into a list of :class:`Parameter` objects.
        """
        # normalize the parameters_default attribute
        if hasattr(cls, "parameters_default") and isinstance(
            cls.parameters_default, dict
        ):
            # default parameters are given as a dictionary
            cls.parameters_default = [
                Parameter(*args) for args in cls.parameters_default.items()
            ]

        # register the subclasses
        super().__init_subclass__(**kwargs)
        if cls is not Parameterized:
            if cls.__name__ in cls._subclasses:
                warnings.warn(f"Redefining class {cls.__name__}")
            cls._subclasses[cls.__name__] = cls

    @classmethod
    def get_parameters(
        cls,
        include_hidden: bool = False,
        include_deprecated: bool = False,
        sort: bool = True,
    ) -> dict[str, Parameter]:
        """Return a dictionary of parameters that the class supports.

        Args:
            include_hidden (bool): Include hidden parameters
            include_deprecated (bool): Include deprecated parameters
            sort (bool): Return ordered dictionary with sorted keys

        Returns:
            dict: a dictionary of instance of :class:`Parameter` with their
            names as keys.

        Raises:
            KeyError: If a :class:`HideParameter` refers to a parameter that is
            not defined by any parent class
        """
        # collect the parameters from the class hierarchy, base classes first
        parameters: dict[str, Parameter] = {}
        hidden_names: set[str] = set()  # parameters removed via HideParameter
        for parent_cls in reversed(cls.__mro__):
            # Only use lists defined on the class itself: an inherited attribute
            # would otherwise be processed again for every subclass, applying a
            # parent's HideParameter repeatedly.
            for p in vars(parent_cls).get("parameters_default", ()):
                if isinstance(p, HideParameter):
                    if p.name in parameters:
                        if include_hidden:
                            # mark a *copy* as hidden: the Parameter instance is
                            # shared with the parent class, whose own listing
                            # must not be affected
                            original = parameters[p.name]
                            hidden_param = object.__new__(type(original))
                            hidden_param.__dict__.update(vars(original))
                            hidden_param.hidden = True
                            parameters[p.name] = hidden_param
                        else:
                            del parameters[p.name]
                        hidden_names.add(p.name)
                    elif p.name not in hidden_names:
                        raise KeyError(f"Cannot hide undefined parameter `{p.name}`")
                else:
                    parameters[p.name] = p
                    hidden_names.discard(p.name)

        # filter parameters based on hidden and deprecated flags
        def show(p):
            """Helper function to decide whether parameter will be shown."""
            # show based on hidden flag?
            show1 = include_hidden or not p.hidden
            # show based on deprecated flag?
            show2 = include_deprecated or not isinstance(p, DeprecatedParameter)
            return show1 and show2

        # filter parameters based on `show`
        result = {
            name: parameter for name, parameter in parameters.items() if show(parameter)
        }

        if sort:
            result = dict(sorted(result.items()))
        return result

    @classmethod
    def _parse_parameters(
        cls,
        parameters: dict[str, Any] | None = None,
        check_validity: bool = True,
        allow_hidden: bool = True,
        include_deprecated: bool = False,
    ) -> dict[str, Any]:
        """Parse parameters.

        Args:
            parameters (dict):
                A dictionary of parameters that will be parsed.
            check_validity (bool):
                Determines whether a `ValueError` is raised if there are keys in
                parameters that are not in the defaults. If `False`, additional
                items are simply stored in `self.parameters`
            allow_hidden (bool):
                Allow setting hidden parameters
            include_deprecated (bool):
                Include deprecated parameters

        Returns:
            dict: the parameter values, converted with the respective
            :meth:`Parameter.convert`

        Raises:
            ValueError: If `check_validity` is set and unknown parameters were
            supplied, or if a value cannot be converted
        """
        if parameters is None:
            parameters = {}
        else:
            parameters = parameters.copy()  # do not modify the original

        # obtain all possible parameters
        param_objs = cls.get_parameters(
            include_hidden=allow_hidden, include_deprecated=include_deprecated
        )

        # initialize parameters with default ones from all parent classes
        result: dict[str, Any] = {}
        for name, param_obj in param_objs.items():
            if not allow_hidden and param_obj.hidden:
                continue  # skip hidden parameters
            # take value from parameters or set default value
            result[name] = param_obj.convert(parameters.pop(name, None))

        # update parameters with the supplied ones
        if check_validity and parameters:
            raise ValueError(
                f"Parameters `{sorted(parameters.keys())}` were provided for an "
                f"instance but are not defined for the class `{cls.__name__}`"
            )
        else:
            result.update(parameters)  # add remaining parameters

        return result

    def get_parameter_default(self, name):
        """Return the default value for the parameter with `name`

        Args:
            name (str): The parameter name

        Returns:
            The default value of the most derived definition of the parameter

        Raises:
            KeyError: If no class in the hierarchy defines the parameter
        """
        for cls in self.__class__.__mro__:
            if hasattr(cls, "parameters_default"):
                for p in cls.parameters_default:
                    if isinstance(p, Parameter) and p.name == name:
                        return p.default_value

        raise KeyError(f"Parameter `{name}` is not defined")

    @classmethod
    def _show_parameters(
        cls,
        description: bool | None = None,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
        parameter_values: dict[str, Any] | None = None,
    ):
        """Private method showing all parameters in human readable format.

        Args:
            description (bool):
                Flag determining whether the parameter description is shown. The
                default is to show the description only when we are in a jupyter
                notebook environment.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown
            parameter_values (dict):
                A dictionary with values to show. Parameters not in this
                dictionary are shown with their default value.

        All flags default to `False`.
        """
        # determine whether we are in a jupyter notebook and can return HTML
        in_notebook = output.in_jupyter_notebook()
        if description is None:
            description = in_notebook  # show only in notebook by default

        # set the templates for displaying the data
        if in_notebook:
            writer: output.OutputBase = output.JupyterOutput(
                '<style type="text/css">dl.py-pde_params dd {padding-left:2em}</style>'
                '<dl class="py-pde_params">',
                "</dl>",
            )
            # templates for HTML output
            template = "<dt>{name} = {value!r}</dt>"
            if description:
                template += "<dd>{description}</dd>"
            template_object = template

        else:
            # template for normal output
            writer = output.BasicOutput()
            template = "{name}: {type} = {value!r}"
            template_object = "{name} = {value!r}"
            if description:
                template += " ({description})"
                template_object += " ({description})"

        # iterate over all parameters
        params = cls.get_parameters(
            include_hidden=show_hidden, include_deprecated=show_deprecated, sort=sort
        )
        for param in params.values():
            # initialize the data to show
            data = {
                "name": param.name,
                "type": param.cls.__name__,
                "description": param.description,
            }

            # determine the value to show; parameters missing from
            # `parameter_values` fall back to their default, as documented
            if parameter_values is None:
                data["value"] = param.default_value
            else:
                data["value"] = parameter_values.get(param.name, param.default_value)

            # print the data to stdout
            if param.cls is object:
                writer(template_object.format(**data))
            else:
                writer(template.format(**data))

        writer.show()

    @hybridmethod
    def show_parameters(
        cls,
        description: bool | None = None,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
    ):
        """Show all parameters in human readable format.

        Args:
            description (bool):
                Flag determining whether the parameter description is shown. The
                default is to show the description only when we are in a jupyter
                notebook environment.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown

        All flags default to `False`.
        """
        cls._show_parameters(description, sort, show_hidden, show_deprecated)

    @show_parameters.instancemethod  # type: ignore
    def show_parameters(
        self,
        description: bool | None = None,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
        default_value: bool = False,
    ):
        """Show all parameters in human readable format.

        Args:
            description (bool):
                Flag determining whether the parameter description is shown. The
                default is to show the description only when we are in a jupyter
                notebook environment.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown
            default_value (bool):
                Flag determining whether the default values or the current
                values are shown

        All flags default to `False`.
        """
        self._show_parameters(
            description,
            sort,
            show_hidden,
            show_deprecated,
            parameter_values=None if default_value else self.parameters,
        )
def get_all_parameters(data: str = "name") -> dict[str, Any]:
    """Get a dictionary with all parameters of all registered classes.

    Args:
        data (str):
            Determines what data is returned. Possible values are 'name',
            'value', or 'description', to return the respective information
            about the parameters.

    Returns:
        dict: Maps the name of each registered :class:`Parameterized` subclass
        to the set of its parameter names (``data == 'name'``) or to a
        dictionary mapping parameter names to their default values
        (``'value'``) or descriptions (``'description'``).

    Raises:
        ValueError: If `data` is not one of the supported values
    """
    # validate eagerly, so invalid input is reported even when no class has
    # been registered yet (a tuple avoids a TypeError for unhashable input)
    if data not in ("name", "value", "description"):
        raise ValueError(f"Cannot interpret data `{data}`")

    result: dict[str, Any] = {}
    for cls_name, cls in Parameterized._subclasses.items():
        parameters = cls.get_parameters()
        if data == "name":
            result[cls_name] = set(parameters)
        elif data == "value":
            result[cls_name] = {k: v.default_value for k, v in parameters.items()}
        else:  # data == "description"
            result[cls_name] = {k: v.description for k, v in parameters.items()}
    return result
def sphinx_display_parameters(app, what, name, obj, options, lines):
    """Append the parameters of a :class:`Parameterized` class to its docstring.

    Only classes whose docstring mentions ``:param parameters:`` are extended;
    `lines` is modified in place and nothing is returned.

    Example:
        This function should be connected to the 'autodoc-process-docstring'
        event like so:

            app.connect('autodoc-process-docstring', sphinx_display_parameters)
    """
    # guard clauses, evaluated in the same order as the combined condition
    if what != "class":
        return
    if not issubclass(obj, Parameterized):
        return
    if not any(":param parameters:" in line for line in lines):
        return

    params = obj.get_parameters(sort=False)
    if not params:
        return

    # admonition header; the leading spaces indent the directive content
    lines.append(".. admonition::")
    lines.append(f" Parameters of {obj.__name__}:")
    lines.append(" ")
    for param in params.values():
        lines.append(f" {param.name}")
        entry = param.description.splitlines() + [
            f"(Default value: :code:`{param.default_value!r}`)"
        ]
        lines.extend(" " + row for row in entry)
        lines.append("")
    lines.append("")