Source code for pde.tools.nested_dict

"""Provides a nested dictionary that stores hierarchical mappings.

.. autosummary::
   :nosignatures:

   NestedDict

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from typing import Any, Generic, Literal, TypeAlias, TypeVar, Union, overload

from typing_extensions import Self

# Type variable for the leaf values stored in the tree; it is bound when
# `NestedDict[...]` is parametrized.
TValue = TypeVar("TValue")
# A tree is a `NestedDict` holding values of type `TValue`. The alias is written as
# a string (forward reference) because the class is defined further down the module.
TNestedDict: TypeAlias = "NestedDict[TValue]"
# An entry of a tree is either a leaf value or another (sub-)tree.
TNestedDictValue = Union[TValue, TNestedDict]  # noqa: UP007
# Plain-`dict` representation of an entire tree: string keys mapping to either leaf
# values or nested plain dictionaries of the same shape.
TDictTree = dict[str, Union[TValue, "TDictTree"]]

# NOTE(review): this binds `T` to the `TypeVar` class itself, not to a type variable,
# and `T` is not used anywhere in the visible code -- confirm whether
# `TypeVar("T")` was intended or whether the name can be removed.
T = TypeVar


class NestedDict(MutableMapping[str, TNestedDictValue], Generic[TValue]):
    """Stores hierarchical mappings with string paths as keys.

    `NestedDict` wraps nested mappings and supports reading and writing nested
    values using a separator-based key syntax (for example ``"a.b.c"``). It can
    convert between flat and nested representations and recursively traverses
    children when requested.

    Note:
        Equivalent entries can overwrite each other during initialization. For
        instance, ``NestedDict({'a.b': 1, 'a': {'b': 2}})`` stores only one final
        value for ``a.b``.
    """

    sep: str = "."
    """str: Separator used in key paths to traverse nested levels."""

    data: MutableMapping[str, TNestedDictValue]
    """dict: Internal mapping storing top-level keys and values for this instance."""

    def __init__(
        self, data: MutableMapping[str, TNestedDictValue] | None = None
    ) -> None:
        """Initializes a nested dictionary from an optional mapping.

        Args:
            data (MutableMapping[str, Any] | None):
                Optional mapping used to populate the instance. Nested mappings
                are converted into `NestedDict` children.
        """
        self.data = self._make_dict()
        if data is not None:
            self.update_recursive(data)

    def _make_dict(self) -> MutableMapping[str, TNestedDictValue]:
        """Create the backing mapping used to store top-level entries."""
        return {}

    def _make_node(self) -> Self:
        """Create an empty child node of the current mapping type."""
        return self.__class__()

    def _node(self, key: str, *, parent: str = "") -> tuple[TNestedDict, str, bool]:
        """Resolve a key path to the owning node and local key.

        Args:
            key:
                Key or nested key path to resolve.
            parent:
                Path prefix (including a trailing separator) of the nodes that
                have already been traversed. Only used for error messages.

        Returns:
            tuple[NestedDict[TValue], str, bool]: The owning node, the final local
            key, and whether the resolved value is itself a nested node.

        Raises:
            TypeError: If `key` is not a string or the path descends into an
                entry that is not a `NestedDict`.
            KeyError: If a component of the path does not exist. The message
                contains the full path up to and including the unresolved part.
        """
        if not isinstance(key, str):
            msg = f"Keys must be strings, not {key!r}"
            raise TypeError(msg)

        if self.sep not in key:
            # key denotes an entry stored directly in this node
            try:
                node = self.data[key]
            except KeyError as err:  # entry did not exist
                msg = f"`{parent}{key}` not in {list(self.data.keys())}"
                raise KeyError(msg) from err
            return self, key, isinstance(node, NestedDict)

        # key denotes an entire branch: resolve the first component here and
        # delegate the remainder to the child node
        child, grandchildren = key.split(self.sep, 1)
        try:
            node = self.data[child]  # next node in branch
        except KeyError as err:  # node did not exist
            msg = f"`{parent}{key}` not in {list(self.data.keys())}"
            raise KeyError(msg) from err
        if not isinstance(node, NestedDict):
            msg = f"`{child}` is not a tree node."
            raise TypeError(msg)

        # Only the component consumed at this level is appended to the prefix, so
        # that `parent + key` is always the path that was originally requested.
        return node._node(grandchildren, parent=parent + child + self.sep)

    def __getitem__(self, key: str) -> TNestedDictValue:
        """Returns an item using dictionary indexing syntax.

        Args:
            key (str): Key or nested key path to resolve.

        Returns:
            Any: Value associated with `key`.

        Raises:
            TypeError: If `key` is not a string or the path descends into an
                entry that is not a `NestedDict`.
            KeyError: If the key path cannot be resolved.
        """
        node, subkey, _ = self._node(key)
        return node.data[subkey]

    def __setitem__(self, key: str, value: TNestedDictValue) -> None:
        """Assigns a value to a key or nested key path.

        If the entry does not exist yet, `value` is stored as given. If the entry
        exists and is a nested node, `value` is merged into it using
        :meth:`update_recursive`. Existing leaf values are overwritten.

        Args:
            key (str):
                Target key. If it contains the separator, missing intermediate
                `NestedDict` nodes are created automatically.
            value (Any):
                Value to store.

        Raises:
            TypeError: If `key` is not a string, if path assignment traverses a
                non-`NestedDict` child, if an existing leaf value would be
                replaced by a mapping, or if an existing node would be replaced
                by a value that is not a mapping.
        """
        if not isinstance(key, str):
            msg = "Keys must be strings"
            raise TypeError(msg)

        try:
            node, subkey, is_tree = self._node(key)

        except KeyError:
            # entry does not exist yet
            if self.sep in key:
                # create all missing parents and store value in the last one
                node_key, value_key = key.rsplit(self.sep, 1)
                subnode: TNestedDict = self.create_node(node_key)
            else:
                subnode, value_key = self, key
            subnode.data[value_key] = value

        else:
            # update existing entry
            if is_tree:
                node.data[subkey].update_recursive(value)
            elif isinstance(value, MutableMapping):
                msg = "Cannot replace normal value with tree"
                raise TypeError(msg)
            else:
                node.data[subkey] = value

    def __delitem__(self, key: str) -> None:
        """Deletes an item addressed by a simple or nested key.

        Args:
            key (str): Key or key path identifying the value to remove.

        Raises:
            KeyError: If the key path cannot be resolved.
        """
        node, subkey, _ = self._node(key)
        del node.data[subkey]

    def __contains__(self, key: object) -> bool:
        """Checks whether a key or key path is present.

        Args:
            key (object):
                Candidate key to test. Anything that is not a string is reported
                as missing.

        Returns:
            bool: `True` if the key path exists, otherwise `False`.
        """
        if not isinstance(key, str):
            return False

        # walk down the path one component at a time
        node: Any = self
        for node_key in key.split(self.sep):
            try:
                node = node[node_key]
            except (KeyError, TypeError):
                return False
        return True

    def __len__(self) -> int:
        """Returns the number of top-level keys.

        Returns:
            int: Number of entries stored at the current level.
        """
        return len(self.data)

    def __iter__(self) -> Iterator[str]:
        """Iterates over top-level keys.

        Returns:
            Iterator[str]: Iterator yielding top-level keys.
        """
        return self.data.__iter__()

    def clear(self) -> None:
        """Removes all top-level entries from the mapping."""
        self.data.clear()

    @overload  # type: ignore
    def values(
        self, *, flatten: Literal[False] = False
    ) -> Iterator[TNestedDictValue]: ...

    @overload
    def values(self, *, flatten: Literal[True]) -> Iterator[TValue]: ...

    def values(self, *, flatten: bool = False) -> Iterator[TNestedDictValue]:
        """Iterates over values, optionally recursing into nested children.

        Args:
            flatten (bool):
                If `True`, yields values from all descendant `NestedDict`
                instances. If `False`, yields only top-level values.

        Returns:
            Iterator[Any]: Iterator over values according to `flatten`.
        """
        if flatten:
            for value in self.data.values():
                if isinstance(value, NestedDict):
                    # recurse into sub dictionary
                    yield from value.values(flatten=True)
                else:
                    yield value
        else:
            yield from self.data.values()

    def keys(self, *, flatten: bool = False) -> Iterator[str]:  # type: ignore
        """Iterates over keys, optionally returning flattened key paths.

        Args:
            flatten (bool):
                If `True`, yields separator-joined paths for descendant keys. If
                `False`, yields only top-level keys.

        Returns:
            Iterator[str]: Iterator over keys or flattened key paths.

        Raises:
            TypeError: If a key used during flattening is not a string.
        """
        if flatten:
            for key, value in self.data.items():
                if isinstance(value, NestedDict):
                    # recurse into sub dictionary
                    for k in value.keys(flatten=True):
                        yield key + self.sep + k
                else:
                    yield key
        else:
            yield from self.data.keys()

    @overload  # type: ignore
    def items(
        self, *, flatten: Literal[False] = False
    ) -> Iterator[tuple[str, TNestedDictValue]]: ...

    @overload
    def items(self, *, flatten: Literal[True]) -> Iterator[tuple[str, TValue]]: ...

    def items(self, *, flatten: bool = False) -> Iterator[tuple[str, TNestedDictValue]]:
        """Iterates over key-value pairs, optionally flattening nested paths.

        Args:
            flatten (bool):
                If `True`, yields `(path, value)` pairs for all descendants. If
                `False`, yields only top-level pairs.

        Returns:
            Iterator[tuple[str, Any]]: Iterator over key-value pairs according to
            `flatten`.

        Raises:
            TypeError: If a key used during flattening is not a string.
        """
        if flatten:
            for key, value in self.data.items():
                if isinstance(value, NestedDict):
                    # recurse into sub dictionary
                    for k, v in value.items(flatten=True):
                        yield key + self.sep + k, v
                else:
                    yield key, value
        else:
            yield from self.data.items()

    def __repr__(self) -> str:
        """Builds a debug representation of this mapping.

        Returns:
            str: String containing the class name and internal data mapping.
        """
        return f"{self.__class__.__name__}({self.data!r})"

    def create_node(self, key: str) -> Self:
        """Create an empty node at the given location.

        Creates all necessary parent nodes recursively. Skips nodes that already
        exist.

        Args:
            key:
                Key or nested key path identifying the node to create.

        Returns:
            The leaf node

        Raises:
            TypeError: If `key` is not a string or if an entry along the path
                already exists but is not a `NestedDict`.
        """
        if not isinstance(key, str):
            msg = "Keys must be strings"
            raise TypeError(msg)

        if self.sep in key:
            # need to create whole branch
            child, grandchildren = key.split(self.sep, 1)
        else:
            child, grandchildren = key, None

        try:
            node = self.data[child]
        except KeyError:
            node = self.data[child] = self._make_node()
        else:
            # An existing leaf value must neither be handed out as a node nor be
            # descended into; fail with the same error that `_node` uses.
            if not isinstance(node, NestedDict):
                msg = f"`{child}` is not a tree node."
                raise TypeError(msg)

        if grandchildren is None:
            return node  # type: ignore
        return node.create_node(grandchildren)  # type: ignore

    def update_recursive(self, other: MutableMapping[str, Any]) -> None:
        """Recursively merges another mapping into this instance.

        Args:
            other (MutableMapping[str, Any]):
                Mapping whose entries are merged into this object. Values that
                are mappings themselves are merged into the (possibly newly
                created) node stored under the respective key; all other values
                are assigned using item assignment.

        Raises:
            TypeError: If `other` is not a mutable mapping or if a mapping would
                have to be merged into an existing entry that is not a node.
        """
        if not isinstance(other, MutableMapping):
            msg = f"Can only update from a mapping, not {type(other).__name__}"
            raise TypeError(msg)

        for k, v in other.items():
            if isinstance(v, MutableMapping):
                self.create_node(k).update_recursive(v)
            else:
                self[k] = v

    def update(self, other) -> None:  # type: ignore
        """Update this mapping from another mapping recursively.

        This method forwards the actual merge to :meth:`update_recursive`.

        Args:
            other:
                Mapping containing keys and values to merge into this instance.

        Raises:
            TypeError: If `other` is not a mutable mapping.
        """
        self.update_recursive(other)

    def copy(self) -> TNestedDict:
        """Creates a structural copy with copied nested mappings.

        Child dictionaries and child `NestedDict` instances are copied, while
        non-mapping leaf values are reused by reference.

        Returns:
            NestedDict: New instance containing copied nested structure.
        """
        res = self._make_node()
        res.update_recursive(self)
        return res

    @overload
    def to_dict(self, *, flatten: Literal[False] = False) -> TDictTree: ...

    @overload
    def to_dict(self, *, flatten: Literal[True]) -> dict[str, TValue]: ...

    def to_dict(self, *, flatten: bool = False) -> TDictTree:  # type: ignore
        """Converts this object to a plain dictionary representation.

        Args:
            flatten (bool):
                If `True`, returns a flat mapping with separator-joined paths as
                keys. If `False`, returns nested dictionaries.

        Returns:
            dict[str, Any]: Dictionary representation of this object.

        Raises:
            TypeError: If flattening encounters non-string keys.
        """
        if flatten:
            return dict(self.items(flatten=True))

        # return hierarchical dictionaries
        return {
            k: v.to_dict(flatten=False) if isinstance(v, NestedDict) else v
            for k, v in self.data.items()
        }

    def pprint(self, *args: Any, flatten: bool = False, **kwargs: Any) -> None:
        """Pretty-prints the current data as plain dictionaries.

        Args:
            *args (Any):
                Positional arguments forwarded to `pprint.pprint`.
            flatten (bool):
                If `True`, prints a flat mapping with separator-joined paths as
                keys. If `False`, prints nested dictionaries.
            **kwargs (Any):
                Keyword arguments forwarded to `pprint.pprint`.
        """
        from pprint import pprint

        pprint(self.to_dict(flatten=flatten), *args, **kwargs)  # type: ignore