"""
Defines a collection of fields to represent multiple fields defined on a common
grid.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
import numpy as np
from ..grids.base import GridBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import Number, number_array
from ..tools.plotting import PlotReference, plot_on_figure
from ..tools.typing import NumberOrArray
from .base import DataFieldBase, FieldBase
from .scalar import ScalarField
class FieldCollection(FieldBase):
    """Collection of fields defined on the same grid

    Note that all fields in the same collection must have the same data type. This might
    lead to upcasting, where for instance a combination of a real-valued and a
    complex-valued field will be both stored as complex fields.
    """

    def __init__(
        self,
        fields: Sequence[DataFieldBase],
        *,
        copy_fields: bool = False,
        label: Optional[str] = None,
        labels: Optional[List[Optional[str]]] = None,
        dtype=None,
    ):
        """
        Args:
            fields:
                Sequence of the individual fields
            copy_fields (bool):
                Flag determining whether the individual fields given in `fields` are
                copied. Note that fields are always copied if some of the supplied
                fields are identical.
            label (str):
                Label of the field collection
            labels (list of str):
                Labels of the individual fields. If omitted, the labels from the
                `fields` argument are used.
            dtype (numpy dtype):
                The data type of the field. All the numpy dtypes are supported. If
                omitted, it will be determined from `data` automatically.
        """
        if isinstance(fields, FieldCollection):
            # support assigning a field collection for convenience
            fields = fields.fields

        if len(fields) == 0:
            raise ValueError("At least one field must be defined")

        # check if grids are compatible
        grid = fields[0].grid
        if any(grid != f.grid for f in fields[1:]):
            grids = [f.grid for f in fields]
            raise RuntimeError(f"Grids are incompatible: {grids}")

        # check whether some fields are identical; if so, we must copy them so each
        # field owns a distinct slice of the collection data
        if not copy_fields and len(fields) != len(set(id(field) for field in fields)):
            self._logger = logging.getLogger(self.__class__.__name__)
            self._logger.warning("Creating a copy of identical fields in collection")
            copy_fields = True

        # create the list of underlying fields
        if copy_fields:
            self._fields = [field.copy() for field in fields]
        else:
            self._fields = fields  # type: ignore

        # extract data from individual fields
        fields_data: List[np.ndarray] = []
        self._slices: List[slice] = []
        dof = 0  # count local degrees of freedom
        for field in self.fields:
            if not isinstance(field, DataFieldBase):
                raise RuntimeError(
                    "Individual fields must be of type DataFieldBase. Field "
                    "collections cannot be nested."
                )
            start = len(fields_data)
            this_data = field._data_flat
            fields_data.extend(this_data)
            self._slices.append(slice(start, len(fields_data)))
            dof += len(this_data)

        # initialize the data from the individual fields; `number_array` may upcast
        # (e.g. real + complex fields => complex collection)
        data_arr = number_array(fields_data, dtype=dtype, copy=False)

        # initialize the class
        super().__init__(grid, data_arr, label=label)

        if not copy_fields:
            # link the data of the original fields back to self._data, so writing to
            # a field modifies the collection and vice versa
            for i, field in enumerate(self.fields):
                field_shape = field.data.shape
                field._data_flat = self._data_full[self._slices[i]]

                # check whether the field data is based on our data field
                assert field.data.shape == field_shape
                assert np.may_share_memory(field._data_full, self._data_full)

        if labels is not None:
            self.labels = labels  # type: ignore
def __repr__(self):
"""return instance as string"""
fields = []
for f in self.fields:
name = f.__class__.__name__
if f.label:
fields.append(f'{name}(..., label="{f.label}")')
else:
fields.append(f"{name}(...)")
return f"{self.__class__.__name__}({', '.join(fields)})"
def __len__(self):
"""return the number of stored fields"""
return len(self.fields)
def __iter__(self) -> Iterator[DataFieldBase]:
"""return iterator over the actual fields"""
return iter(self.fields)
def __getitem__(self, index: Union[int, str]) -> DataFieldBase:
"""return a specific field"""
if isinstance(index, int):
# simple numerical index
return self.fields[index]
elif isinstance(index, str):
# index specifying the label of the field
for field in self.fields:
if field.label == index:
return field
raise KeyError(f"No field with name `{index}`")
else:
raise TypeError(f"Unsupported index `{index}`")
def __setitem__(self, index: Union[int, str], value: NumberOrArray):
"""set the value of a specific field
Args:
index (int or str):
Determines which field is updated. If `index` is an integer it specifies
the position of the field that will be updated. If `index` is a string,
the first field with this name will be updated.
value (float or :class:`~numpy.ndarray`):
The updated value(s) of the chosen field.
"""
# We need to load the field and set data explicitly
# WARNING: Do not use `self.fields[index] = value`, since this would
# break the connection between the data fields
if isinstance(index, int):
# simple numerical index
self.fields[index].data = value # type: ignore
elif isinstance(index, str):
# index specifying the label of the field
for field in self.fields:
if field.label == index:
field.data = value # type: ignore
break # indicates that a field has been found
else:
raise KeyError(f"No field with name `{index}`")
else:
raise TypeError(f"Unsupported index `{index}`")
@property
def fields(self) -> List[DataFieldBase]:
"""list: the fields of this collection"""
return self._fields
@property
def labels(self) -> "_FieldLabels":
""":class:`_FieldLabels`: the labels of all fields
Note:
The attribute returns a special class :class:`_FieldLabels` to allow
specific manipulations of the field labels. The returned object behaves
much like a list, but assigning values will modify the labels of the fields
in the collection.
"""
return _FieldLabels(self)
@labels.setter
def labels(self, values: List[Optional[str]]):
"""sets the labels of all fields"""
if len(values) != len(self):
raise ValueError("Require a label for each field")
for field, value in zip(self.fields, values):
field.label = value
def __eq__(self, other):
"""test fields for equality, ignoring the label"""
if not isinstance(other, self.__class__):
return NotImplemented
return self.fields == other.fields
[docs] @classmethod
def from_state(
cls, attributes: Dict[str, Any], data: np.ndarray = None
) -> FieldCollection:
"""create a field collection from given state.
Args:
attributes (dict):
The attributes that describe the current instance
data (:class:`~numpy.ndarray`, optional):
Data values at support points of the grid defining all fields
"""
if "class" in attributes:
class_name = attributes.pop("class")
assert class_name == cls.__name__
# restore the individual fields (without data)
fields = [
FieldBase.from_state(field_state)
for field_state in attributes.pop("fields")
]
# create the collection
collection = cls(fields, **attributes) # type: ignore
if data is not None:
collection.data = data # set the data of all fields
return collection
@classmethod
def _from_hdf_dataset(cls, dataset) -> FieldCollection:
"""construct the class by reading data from an hdf5 dataset"""
# copy attributes from hdf
attributes = dict(dataset.attrs)
# determine class
class_name = json.loads(attributes.pop("class"))
assert class_name == cls.__name__
# determine the fields
field_attrs = json.loads(attributes.pop("fields"))
fields = [
DataFieldBase._from_hdf_dataset(dataset[f"field_{i}"])
for i in range(len(field_attrs))
]
# unserialize remaining attributes
attributes = cls.unserialize_attributes(attributes)
return cls(fields, **attributes) # type: ignore
def _write_hdf_dataset(self, hdf_path):
"""write data to a given hdf5 path `hdf_path`"""
# write attributes of the collection
for key, value in self.attributes_serialized.items():
hdf_path.attrs[key] = value
# write individual fields
for i, field in enumerate(self.fields):
field._write_hdf_dataset(hdf_path, f"field_{i}")
[docs] def assert_field_compatible(self, other: FieldBase, accept_scalar: bool = False):
"""checks whether `other` is compatible with the current field
Args:
other (FieldBase): Other field this is compared to
accept_scalar (bool, optional): Determines whether it is acceptable
that `other` is an instance of
:class:`~pde.fields.ScalarField`.
"""
super().assert_field_compatible(other, accept_scalar=accept_scalar)
# check whether all sub fields are compatible
if isinstance(other, FieldCollection):
for f1, f2 in zip(self, other):
f1.assert_field_compatible(f2, accept_scalar=accept_scalar)
[docs] @classmethod
@fill_in_docstring
def from_scalar_expressions(
cls,
grid: GridBase,
expressions: Sequence[str],
*,
label: str = None,
labels: Optional[Sequence[str]] = None,
dtype=None,
) -> FieldCollection:
"""create a field collection on a grid from given expressions
Warning:
{WARNING_EXEC}
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined
expressions (list of str):
A list of mathematical expression, one for each field in the
collection. The expressions determine the values as a function
of the position on the grid. The expressions may contain
standard mathematical functions and they may depend on the axes
labels of the grid.
label (str, optional):
Name of the whole collection
labels (list of str, optional):
Names of the individual fields
dtype (numpy dtype):
The data type of the field. All the numpy dtypes are supported. If
omitted, it will be determined from `data` automatically.
"""
if isinstance(expressions, str):
expressions = [expressions]
if labels is None:
labels = [None] * len(expressions) # type: ignore
# evaluate all expressions at all points
fields = [
ScalarField.from_expression(grid, expression, label=labels[i], dtype=dtype)
for i, expression in enumerate(expressions)
]
# create vector field from the data
return cls(fields=fields, label=label) # lgtm [py/call-to-non-callable]
@property
def attributes(self) -> Dict[str, Any]:
"""dict: describes the state of the instance (without the data)"""
results = super().attributes
# store the attributes of the individual fields in a separate attribute
results["fields"] = [f.attributes for f in self.fields]
# the grid information does not need to be stored since it is included in the
# attributes of the individual fields
del results["grid"]
return results
@property
def attributes_serialized(self) -> Dict[str, str]:
"""dict: serialized version of the attributes"""
results = {}
for key, value in self.attributes.items():
if key == "fields":
fields = [f.attributes_serialized for f in self.fields]
results[key] = json.dumps(fields)
elif key == "dtype":
results[key] = json.dumps(value.str)
else:
results[key] = json.dumps(value)
return results
[docs] @classmethod
def unserialize_attributes(cls, attributes: Dict[str, str]) -> Dict[str, Any]:
"""unserializes the given attributes
Args:
attributes (dict):
The serialized attributes
Returns:
dict: The unserialized attributes
"""
results = {}
for key, value in attributes.items():
if key == "fields":
results[key] = [
FieldBase.unserialize_attributes(attrs)
for attrs in json.loads(value)
]
else:
results[key] = json.loads(value)
return results
[docs] def copy(
self: FieldCollection,
*,
label: str = None,
dtype=None,
) -> FieldCollection:
"""return a copy of the data, but not of the grid
Args:
label (str, optional):
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it will be determined from
`data` automatically.
"""
if label is None:
label = self.label
fields = [f.copy() for f in self.fields]
# create the collection from the copied fields
return self.__class__(fields, copy_fields=False, label=label, dtype=dtype)
[docs] def interpolate_to_grid(
self,
grid: GridBase,
*,
backend: str = "numba",
method: str = "linear",
fill: Number = None,
label: Optional[str] = None,
) -> FieldCollection:
"""interpolate the data of this field collection to another grid.
Args:
grid (:class:`~pde.grids.base.GridBase`):
The grid of the new field onto which the current field is
interpolated.
backend (str):
The accepted values "scipy" and "numba" determine the backend that is
used for the interpolation.
method (str):
Determines the method being used for interpolation. Typical values that
are "nearest" and "linear", but the supported values depend on the
chosen `backend`.
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
label (str, optional):
Name of the returned field collection
Returns:
:class:`~pde.fields.coolection.FieldCollection`: Interpolated data
"""
if label is None:
label = self.label
fields = [
f.interpolate_to_grid(grid, backend=backend, method=method, fill=fill)
for f in self.fields
]
return self.__class__(fields, label=label)
[docs] def smooth(
self,
sigma: float = 1,
*,
out: Optional[FieldCollection] = None,
label: str = None,
) -> FieldCollection:
"""applies Gaussian smoothing with the given standard deviation
This function respects periodic boundary conditions of the underlying
grid, using reflection when no periodicity is specified.
sigma (float):
Gives the standard deviation of the smoothing in real length units
(default: 1)
out (FieldCollection, optional):
Optional field into which the smoothed data is stored
label (str, optional):
Name of the returned field
Returns:
Field collection with smoothed data, stored at `out` if given.
"""
# allocate memory for storing output
if out is None:
out = self.copy(label=label)
else:
self.assert_field_compatible(out)
if label:
out.label = label
# apply Gaussian smoothing for each axis
for f_in, f_out in zip(self, out):
f_in.smooth(sigma=sigma, out=f_out)
return out
@property
def integrals(self) -> List:
"""integrals of all fields"""
return [field.integral for field in self]
@property
def averages(self) -> List:
"""averages of all fields"""
return [field.average for field in self]
@property
def magnitudes(self) -> np.ndarray:
""":class:`~numpy.ndarray`: scalar magnitudes of all fields"""
return np.array([field.magnitude for field in self])
[docs] def get_line_data( # type: ignore
self,
index: int = 0,
scalar: str = "auto",
extract: str = "auto",
) -> Dict[str, Any]:
r"""return data for a line plot of the field
Args:
index (int):
Index of the field whose data is returned
scalar (str or int):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
extract (str):
The method used for extracting the line data. See the docstring
of the grid method `get_line_data` to find supported values.
Returns:
dict: Information useful for performing a line plot of the field
"""
return self[index].get_line_data(scalar=scalar, extract=extract)
[docs] def get_image_data(self, index: int = 0, **kwargs) -> Dict[str, Any]:
r"""return data for plotting an image of the field
Args:
index (int): Index of the field whose data is returned
\**kwargs: Arguments forwarded to the `get_image_data` method
Returns:
dict: Information useful for plotting an image of the field
"""
return self[index].get_image_data(**kwargs)
def _update_plot(self, reference: List[PlotReference]) -> None:
"""update a plot collection with the current field values
Args:
reference (list of :class:`PlotReference`):
All references of the plot to update
"""
for field, ref in zip(self.fields, reference):
field._update_plot(ref)
[docs] @plot_on_figure(update_method="_update_plot")
def plot(
self,
kind: Union[str, Sequence[str]] = "auto",
resize_fig=None,
figsize="auto",
arrangement="horizontal",
fig=None,
subplot_args=None,
**kwargs,
) -> List[PlotReference]:
r"""visualize all the fields in the collection
Args:
kind (str or list of str):
Determines the kind of the visualizations. Supported values are `image`,
`line`, `vector`, or `interactive`. Alternatively, `auto` determines the
best visualization based on each field itself. Instead of a single value
for all fields, a list with individual values can be given.
resize_fig (bool):
Whether to resize the figure to adjust to the number of panels
figsize (str or tuple of numbers):
Determines the figure size. The figure size is unchanged if the string
`default` is passed. Conversely, the size is adjusted automatically when
`auto` is passed. Finally, a specific figure size can be specified using
two values, using :func:`matplotlib.figure.Figure.set_size_inches`.
arrangement (str):
Determines how the subpanels will be arranged. The default value
`horizontal` places all subplots next to each other. The alternative
value `vertical` puts them below each other.
{PLOT_ARGS}
subplot_args (list):
Additional arguments for the specific subplots. Should be a list with a
dictionary of arguments for each subplot. Supplying an empty dict allows
to keep the default setting of specific subplots.
\**kwargs:
All additional keyword arguments are forwarded to the actual plotting
function of all subplots.
Returns:
List of :class:`PlotReference`: Instances that contain information
to update all the plots with new data later.
"""
if resize_fig is not None:
# Deprecated this argument on 2021-02-01
import warnings
warnings.warn(
"`resize_fig` is a deprecated argument. Use `figsize` directly",
DeprecationWarning,
)
if resize_fig is True:
figsize = "auto"
elif resize_fig is False:
figsize = "default"
else:
raise ValueError
# set the size of the figure
if figsize == "default":
pass # just leave the figure size at its default value
elif figsize == "auto":
# adjust the size of the figure
if arrangement == "horizontal":
fig.set_size_inches((4 * len(self), 3), forward=True)
elif arrangement == "vertical":
fig.set_size_inches((4, 3 * len(self)), forward=True)
else:
# assume that an actual tuple is given
fig.set_size_inches(figsize, forward=True)
# create all the subpanels
if arrangement == "horizontal":
(axs,) = fig.subplots(1, len(self), squeeze=False)
elif arrangement == "vertical":
axs = fig.subplots(len(self), 1, squeeze=False)
axs = [a[0] for a in axs] # transpose
else:
raise ValueError(f"Unknown arrangement `{arrangement}`")
if subplot_args is None:
subplot_args = [{}] * len(self)
if isinstance(kind, str):
kind = [kind] * len(self.fields)
# plot all the elements onto the respective axes
reference = [
field.plot(kind=knd, ax=ax, action="create", **kwargs, **sp_args)
for field, knd, ax, sp_args in zip( # @UnusedVariable
self.fields, kind, axs, subplot_args
)
]
# return the references for all subplots
return reference
def _get_napari_data(self, **kwargs) -> Dict[str, Dict[str, Any]]:
r"""returns data for plotting all fields
Args:
\**kwargs: all arguments are forwarded to `_get_napari_layer_data`
Returns:
dict: all the information necessary to plot all fields
"""
result = {}
for i, field in enumerate(self, 1):
name = f"Field {i}" if field.label is None else field.label
result[name] = field._get_napari_layer_data(**kwargs)
return result
class _FieldLabels:
"""helper class that allows manipulating all labels of field collections"""
def __init__(self, collection: FieldCollection):
"""
Args:
collection (:class:`pde.fields.collection.FieldCollection`):
The field collection that these labels are associated with
"""
self.collection = collection
def __repr__(self):
return repr(list(self))
def __str__(self):
return str(list(self))
def __len__(self):
return len(self.collection)
def __eq__(self, other):
return list(self) == list(other)
def __iter__(self) -> Iterator[Optional[str]]:
for field in self.collection:
yield field.label
def __getitem__(
self, index: Union[int, slice]
) -> Union[Optional[str], List[Optional[str]]]:
"""return one or many labels of a field in the collection"""
if isinstance(index, int):
return self.collection[index].label
elif isinstance(index, slice):
return list(self)[index]
else:
raise TypeError("Unsupported index type")
def __setitem__(
self, index: Union[int, slice], value: Union[Optional[str], List[Optional[str]]]
):
"""change one or many labels of a field in the collection"""
if isinstance(index, int):
self.collection.fields[index].label = value # type: ignore
elif isinstance(index, slice):
fields = self.collection.fields[index]
if value is None or isinstance(value, str):
value = [value] * len(fields)
if len(fields) != len(value):
raise ValueError("Require a label for each field")
for field, label in zip(fields, value):
field.label = label
else:
raise TypeError("Unsupported index type")
def index(self, label: str) -> int:
"""return the index in the field labels where a certain label is stored
Args:
label (str):
The label that is sought
Returns:
int: The index in the list (or `ValueError` if it cannot be found)
"""
for i, field in enumerate(self.collection):
if field.label == label:
return i
raise ValueError(f"Label `{label}` is not present in the collection")