"""
Defines base classes of fields, which are discretized on grids

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import functools
import json
import logging
import warnings
from abc import ABCMeta, abstractmethod
from inspect import isabstract
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
from ..grids.base import DimensionError, DomainError, GridBase, discretize_interval
from ..grids.boundaries.axes import BoundariesData
from ..grids.cartesian import CartesianGridBase
from ..tools.cache import cached_method
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import Number, number_array
from ..tools.numba import jit, make_array_constructor, numba_dict
from ..tools.plotting import (
PlotReference,
napari_add_layers,
napari_viewer,
plot_on_axes,
)
from ..tools.typing import ArrayLike, NumberOrArray
if TYPE_CHECKING:
from .scalar import ScalarField # @UnusedImport
TField = TypeVar("TField", bound="FieldBase")
class RankError(TypeError):
"""error indicating that the field has the wrong rank"""
class FieldBase(metaclass=ABCMeta):
"""abstract base class for describing (discretized) fields"""
_subclasses: Dict[str, Type[FieldBase]] = {} # all classes inheriting from this
_grid: GridBase # the grid on which the field is defined
__data_full: np.ndarray # the data on the grid including ghost points
_data_valid: np.ndarray # the valid data without ghost points
_label: Optional[str]
def __init__(
self,
grid: GridBase,
data: np.ndarray,
*,
label: Optional[str] = None,
):
"""
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined
data (:class:`~numpy.ndarray`, optional):
Field values at the support points of the grid and the ghost cells
label (str, optional):
Name of the field
"""
self._grid = grid
self._data_full = data
self.label = label
self._logger = logging.getLogger(self.__class__.__name__)
def __init_subclass__(cls, **kwargs): # @NoSelf
"""register all subclassess to reconstruct them later"""
super().__init_subclass__(**kwargs)
cls._subclasses[cls.__name__] = cls
@property
def data(self) -> np.ndarray:
""":class:`~numpy.ndarray`: discretized data at the support points"""
return self._data_valid
@data.setter
def data(self, value: NumberOrArray) -> None:
"""set the valid data of the field
Args:
value:
The value of the valid data. If a scalar is supplied, all data points
get the same value. The values of the ghost cells are not changed.
"""
if isinstance(value, FieldBase):
# copy data into current field
self.assert_field_compatible(value, accept_scalar=True)
self._data_valid[:] = value.data
else:
self._data_valid[:] = value
@property
def _idx_valid(self) -> Tuple[slice, ...]:
"""tuple: slices to extract valid data from full data"""
idx_comp = (slice(None),) * (self.__data_full.ndim - self.grid.num_axes)
return idx_comp + self.grid._idx_valid
@property
def _data_full(self) -> np.ndarray:
""":class:`~numpy.ndarray`: the full data including ghost cells"""
return self.__data_full
@_data_full.setter
def _data_full(self, value: NumberOrArray) -> None:
"""set the full data including ghost cells
Args:
value:
The value of the full data including those for ghost cells. If a
scalar is supplied, all data points get the same value.
"""
if not self.writeable:
raise ValueError("assignment destination is read-only")
if np.isscalar(value):
# supplied value is a scalar
self.__data_full[:] = value
elif isinstance(value, np.ndarray):
# check the shape of the supplied array
if value.shape[-self.grid.num_axes :] != self.grid._shape_full:
raise ValueError(
f"Supplied data has wrong shape: {value.shape} is not compatible "
f"with {self.grid._shape_full}"
)
# actually set the data
self.__data_full = value
else:
raise TypeError(f"Cannot set field values to {value}")
# set reference to valid data
self._data_valid = self.__data_full[self._idx_valid]
@property
def _data_flat(self) -> np.ndarray:
""":class:`~numpy.ndarray`: flat version of discretized data with ghost cells"""
# flatten the first dimension of the internal data by creating a view and then
# setting the new shape. This disallows accidental copying of the data
data_flat = self._data_full.view()
data_flat.shape = (-1, *self.grid._shape_full)
return data_flat
@_data_flat.setter
def _data_flat(self, value: np.ndarray) -> None:
"""set the full data including ghost cells from a flattened array"""
# simply set the data -> this might need to be overwritten
self._data_full = value
@property
def writeable(self) -> bool:
"""bool: whether the field data can be changed or not"""
return not hasattr(self, "_data_full") or self._data_full.flags.writeable
@writeable.setter
def writeable(self, value: bool) -> None:
"""set whether the field data can be changed or not"""
self._data_full.flags.writeable = value
self._data_valid.flags.writeable = value
@property
def label(self) -> Optional[str]:
"""str: the name of the field"""
return self._label
@label.setter
def label(self, value: Optional[str] = None):
"""set the new label of the field"""
if value is None or isinstance(value, str):
self._label = value
else:
raise TypeError("Label must be a string or None")
@classmethod
def from_state(
cls, attributes: Dict[str, Any], data: np.ndarray = None
) -> FieldBase:
"""create a field from given state.
Args:
attributes (dict):
The attributes that describe the current instance
data (:class:`~numpy.ndarray`, optional):
Data values at the support points of the grid defining the field
"""
# base class was chosen => select correct class from attributes
class_name = attributes.pop("class")
if class_name == cls.__name__:
raise RuntimeError(f"Cannot reconstruct abstract class `{class_name}`")
# call possibly overwritten classmethod from subclass
return cls._subclasses[class_name].from_state(attributes, data)
@classmethod
def from_file(cls, filename: str) -> FieldBase:
"""create field from data stored in a file
Field can be written to a file using :meth:`FieldBase.to_file`.
Example:
Write a field to a file and then read it back:
.. code-block:: python
field = pde.ScalarField(...)
field.write_to("test.hdf5")
field_copy = pde.FieldBase.from_file("test.hdf5")
Args:
filename (str): Path to the file being read
Returns:
:class:`FieldBase`: The field with the appropriate sub-class
"""
import h5py
from .collection import FieldCollection
with h5py.File(filename, "r") as fp:
if "class" in fp.attrs:
# this should be a field collection
assert json.loads(fp.attrs["class"]) == "FieldCollection"
obj = FieldCollection._from_hdf_dataset(fp)
elif len(fp) == 1:
# a single field is stored in the data
dataset = fp[list(fp.keys())[0]] # retrieve only dataset
obj = cls._from_hdf_dataset(dataset) # type: ignore
else:
raise RuntimeError(
"Multiple data fields were found in the "
"file but no FieldCollection is expected"
)
return obj
@classmethod
def _from_hdf_dataset(cls, dataset) -> FieldBase:
"""construct a field by reading data from an hdf5 dataset"""
# copy attributes from hdf
attributes = dict(dataset.attrs)
# determine class
class_name = json.loads(attributes.pop("class"))
field_cls = cls._subclasses[class_name]
# unserialize the attributes
attributes = field_cls.unserialize_attributes(attributes)
return field_cls.from_state(attributes, data=dataset)
@property
def grid(self) -> GridBase:
"""GridBase: The grid on which the field is defined"""
return self._grid
def to_file(self, filename: str, **kwargs):
r"""store field in a file
The extension of the filename determines what format is being used. If it ends
in `.h5` or `.hdf`, the Hierarchical Data Format is used. The only other
supported format is images, of which only the most typical ones are supported.
To load the field back from the file, you may use :meth:`FieldBase.from_file`.
Example:
Write a field to a file and then read it back:
.. code-block:: python
field = pde.ScalarField(...)
field.write_to("test.hdf5")
field_copy = pde.FieldBase.from_file("test.hdf5")
Args:
filename (str):
Path where the data is stored
\**kwargs:
Additional parameters may be supported for some formats
"""
extension = Path(filename).suffix.lower()
if extension in {".hdf", ".hdf5", ".he5", ".h5"}:
# save data in hdf5 format
import h5py
with h5py.File(filename, "w") as fp:
self._write_hdf_dataset(fp, **kwargs)
elif extension in {".png", ".jpg", ".jpeg", ".tif", ".pdf", ".svg"}:
# save data as an image
self._write_to_image(filename, **kwargs)
else:
raise ValueError(f"Do not know how to save data to `*{extension}`")
def _write_hdf_dataset(self, hdf_path, key: str = "data"):
"""write data to a given hdf5 path `hdf_path`"""
# write the data
dataset = hdf_path.create_dataset(key, data=self.data)
# write attributes
for key, value in self.attributes_serialized.items():
dataset.attrs[key] = value
def _write_to_image(self, filename: str, **kwargs):
"""write data to image
Args:
filename (str): The path to the image that will be created
"""
raise NotImplementedError(f"Cannot save {self.__class__.__name__} as an image")
@abstractmethod
def copy(
self: TField,
*,
label: Optional[str] = None,
dtype=None,
) -> TField:
pass
def assert_field_compatible(self, other: FieldBase, accept_scalar: bool = False):
"""checks whether `other` is compatible with the current field
Args:
other (FieldBase):
The other field this one is compared to
accept_scalar (bool, optional):
Determines whether it is acceptable that `other` is an instance of
:class:`~pde.fields.ScalarField`.
"""
from .scalar import ScalarField # @Reimport
# check whether they are the same class
is_scalar = accept_scalar and isinstance(other, ScalarField)
class_compatible = self.__class__ == other.__class__ or is_scalar
if not class_compatible:
raise TypeError(f"Fields {self} and {other} are incompatible")
# check whether the associated grids are identical
if not self.grid.compatible_with(other.grid):
raise ValueError(f"Grids {self.grid} and {other.grid} are incompatible")
@property
def dtype(self):
"""returns the numpy dtype of the underlying data"""
# this property is necessary to support np.iscomplexobj for DataFieldBases
return self.data.dtype
@property
def is_complex(self) -> bool:
"""bool: whether the field contains real or complex data"""
return np.iscomplexobj(self.data) # type: ignore
@property
def attributes(self) -> Dict[str, Any]:
"""dict: describes the state of the instance (without the data)"""
return {
"class": self.__class__.__name__,
"grid": self.grid,
"label": self.label,
"dtype": self.dtype,
}
@property
def attributes_serialized(self) -> Dict[str, str]:
"""dict: serialized version of the attributes"""
results = {}
for key, value in self.attributes.items():
if key == "grid":
results[key] = value.state_serialized
elif key == "dtype":
results[key] = json.dumps(value.str)
else:
results[key] = json.dumps(value)
return results
@classmethod
def unserialize_attributes(cls, attributes: Dict[str, str]) -> Dict[str, Any]:
"""unserializes the given attributes
Args:
attributes (dict):
The serialized attributes
Returns:
dict: The unserialized attributes
"""
# base class was chosen => select correct class from attributes
class_name = json.loads(attributes["class"])
if class_name == cls.__name__:
raise RuntimeError(f"Cannot reconstruct abstract class `{class_name}`")
# call possibly overwritten classmethod from subclass
return cls._subclasses[class_name].unserialize_attributes(attributes)
def __eq__(self, other):
"""test fields for equality, ignoring the label"""
if not isinstance(other, self.__class__):
return NotImplemented
return self.grid == other.grid and np.array_equal(self.data, other.data)
def _unary_operation(self: TField, op: Callable) -> TField:
"""perform an unary operation on this field
Args:
op (callable):
A function calculating the result
Returns:
FieldBase: A field that contains the result of the operation.
"""
data = op(self.data)
result = self.copy(dtype=data.dtype)
result.data = data
return result
@property
def real(self: TField) -> TField:
""":class:`FieldBase`: Real part of the field"""
return self._unary_operation(np.real)
@property
def imag(self: TField) -> TField:
""":class:`FieldBase`: Imaginary part of the field"""
return self._unary_operation(np.imag)
def conjugate(self: TField) -> TField:
"""returns complex conjugate of the field"""
return self._unary_operation(np.conjugate)
def __neg__(self):
"""return the negative of the current field"""
return self._unary_operation(np.negative)
def _binary_operation(
self, other, op: Callable, scalar_second: bool = True
) -> FieldBase:
"""perform a binary operation between this field and `other`
Args:
other (number or FieldBase):
The second term of the operator
op (callable):
A binary function calculating the result
scalar_second (bool):
Flag determining whether the second operator must be a scalar
Returns:
FieldBase: A field that contains the result of the operation. If
`scalar_second == True`, the type of FieldBase is the same as `self`
"""
# determine the dtype of the output
if isinstance(other, FieldBase):
# right operator is a field
from .scalar import ScalarField # @Reimport
# determine the dtype of the result of the operation
dtype = np.result_type(self.data, other.data)
if scalar_second:
# right operator must be a scalar or scalar field
if not isinstance(other, ScalarField):
raise TypeError("Right operator must be a scalar field")
self.grid.assert_grid_compatible(other.grid)
result: FieldBase = self.copy(dtype=dtype)
elif isinstance(self, ScalarField):
# left operator is a scalar field (right can be tensor)
self.grid.assert_grid_compatible(other.grid)
result = other.copy(dtype=dtype)
else:
# left operator is tensor and right one might be anything
self.assert_field_compatible(other, accept_scalar=True)
result = self.copy(dtype=dtype)
op(self.data, other.data, out=result.data)
else:
# the second operator is a number or a numpy array
dtype = np.result_type(self.data, other)
result = self.copy(dtype=dtype)
op(self.data, other, out=result.data)
return result
def _binary_operation_inplace(
self: TField, other, op_inplace: Callable, scalar_second: bool = True
) -> TField:
"""perform an in-place binary operation between this field and `other`
Args:
other (number or FieldBase):
The second term of the operator
op_inplace (callable):
A binary function storing its result in the first argument
scalar_second (bool):
Flag determining whether the second operator must be a scalar.
Returns:
FieldBase: The field `self` with updated data
"""
if isinstance(other, FieldBase):
# right operator is a field
from .scalar import ScalarField # @Reimport
if scalar_second:
# right operator must be a scalar
if not isinstance(other, ScalarField):
raise TypeError("Right operator must be a scalar field")
self.grid.assert_grid_compatible(other.grid)
else:
# left operator is tensor and right one might be anything
self.assert_field_compatible(other, accept_scalar=True)
# operators only affect the valid data and do not touch the ghost cells
op_inplace(self.data, other.data, out=self.data)
else:
# the second operator is a number or a numpy array
op_inplace(self.data, other, out=self.data)
return self
def __add__(self, other) -> FieldBase:
"""add two fields"""
return self._binary_operation(other, np.add, scalar_second=False)
__radd__ = __add__
def __iadd__(self: TField, other) -> TField:
"""add `other` to the current field"""
return self._binary_operation_inplace(other, np.add, scalar_second=False)
def __sub__(self, other) -> FieldBase:
"""subtract two fields"""
return self._binary_operation(other, np.subtract, scalar_second=False)
def __rsub__(self, other) -> FieldBase:
"""subtract two fields"""
return self._binary_operation(
other, lambda x, y, out: np.subtract(y, x, out=out), scalar_second=False
)
def __isub__(self: TField, other) -> TField:
"""add `other` to the current field"""
return self._binary_operation_inplace(other, np.subtract, scalar_second=False)
def __mul__(self, other) -> FieldBase:
"""multiply field by value"""
return self._binary_operation(other, np.multiply, scalar_second=False)
__rmul__ = __mul__
def __imul__(self: TField, other) -> TField:
"""multiply field by value"""
return self._binary_operation_inplace(other, np.multiply, scalar_second=False)
def __truediv__(self, other) -> FieldBase:
"""divide field by value"""
return self._binary_operation(other, np.true_divide, scalar_second=True)
def __rtruediv__(self, other) -> FieldBase:
"""divide field by value"""
def rdivision(x, y, **kwargs):
return np.true_divide(y, x, **kwargs)
return self._binary_operation(other, rdivision, scalar_second=True)
def __itruediv__(self: TField, other) -> TField:
"""divide field by value"""
return self._binary_operation_inplace(other, np.true_divide, scalar_second=True)
def __pow__(self, exponent: float) -> FieldBase:
"""raise data of the field to a certain power"""
if not np.isscalar(exponent):
raise NotImplementedError("Only scalar exponents are supported")
return self._binary_operation(exponent, np.power, scalar_second=True)
def __ipow__(self: TField, exponent: float) -> TField:
"""raise data of the field to a certain power in-place"""
if not np.isscalar(exponent):
raise NotImplementedError("Only scalar exponents are supported")
self.data **= exponent
return self
def apply(
self: TField, func: Callable, out: Optional[TField] = None, label: Optional[str] = None
) -> TField:
"""applies a function to the data and returns it as a field
Args:
func (callable or str):
The (vectorized) function being applied to the data or the name
of an operator that is defined for the grid of this field.
out (FieldBase, optional):
Optional field into which the data is written
label (str, optional):
Name of the returned field
Returns:
Field with new data. This is stored at `out` if given.
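Example:
A minimal usage sketch (the grid and function are illustrative; assumes
``import pde`` and ``import numpy as np``):
.. code-block:: python
grid = pde.UnitGrid([32])
field = pde.ScalarField.random_normal(grid)
# apply a vectorized function to the field data
squared = field.apply(np.square, label="squared field")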
"""
if out is None:
out = self.copy(label=label)
out.data = func(self.data)
else:
self.assert_field_compatible(out)
out.data[:] = func(self.data)
if label:
out.label = label
return out
@abstractmethod
def get_line_data(
self, scalar: str = "auto", extract: str = "auto"
) -> Dict[str, Any]:
pass
@abstractmethod
def get_image_data(self) -> Dict[str, Any]:
pass
@abstractmethod
def plot(self, *args, **kwargs):
pass
@abstractmethod
def _get_napari_data(self, **kwargs) -> Dict[str, Dict[str, Any]]:
pass
def plot_interactive(self, viewer_args: Dict[str, Any] = None, **kwargs):
"""create an interactive plot of the field using :mod:`napari`
For a detailed description of the launched program, see the
`napari webpage <http://napari.org/>`_.
Args:
viewer_args (dict):
Arguments passed to :class:`napari.viewer.Viewer` to affect the viewer.
**kwargs:
Extra arguments passed to the plotting function
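Example:
A minimal sketch, assuming :mod:`napari` is installed and the field is
at least two-dimensional (the grid is illustrative):
.. code-block:: python
grid = pde.UnitGrid([64, 64])
field = pde.ScalarField.random_normal(grid)
field.plot_interactive()  # opens a napari window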
"""
if viewer_args is None:
viewer_args = {}
if self.grid.num_axes == 1:
raise RuntimeError(
"Interactive plotting needs at least 2 spatial dimensions"
)
with napari_viewer(self.grid, **viewer_args) as viewer:
napari_add_layers(viewer, self._get_napari_data(**kwargs))
TDataField = TypeVar("TDataField", bound="DataFieldBase")
class DataFieldBase(FieldBase, metaclass=ABCMeta):
"""abstract base class for describing fields of single entities"""
rank: int # the rank of the tensor field
def __init__(
self,
grid: GridBase,
data: Optional[Union[ArrayLike, str]] = "zeros",
*,
label: Optional[str] = None,
dtype=None,
with_ghost_cells: bool = False,
):
"""
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined.
data (Number or :class:`~numpy.ndarray`, optional):
Field values at the support points of the grid. The flag
`with_ghost_cells` determines whether this data array contains values
for the ghost cells, too. The resulting field will contain real data
unless the `data` argument contains complex values. Special values are
"zeros" or None, initializing the field with zeros, and "empty", just
allocating memory with unspecified values.
label (str, optional):
Name of the field
dtype (numpy dtype):
The data type of the field. All the numpy dtypes are supported. If
omitted, it will be determined from `data` automatically.
with_ghost_cells (bool):
Indicates whether the ghost cells are included in data
"""
if isinstance(data, self.__class__):
# special case where a DataFieldBase is supplied
data_arr = number_array(data._data_full, dtype=dtype, copy=True)
super().__init__(grid, data=data_arr, label=label)
elif with_ghost_cells:
# use full data without copying (unless necessary)
if data is None or isinstance(data, str):
raise ValueError("`data` must be supplied if with_ghost_cells==True")
data_arr = number_array(data, dtype=dtype, copy=False)
super().__init__(grid, data=data_arr, label=label)
else:
# ghost cells are not supplied => allocate full array and write valid data
full_shape = (grid.dim,) * self.rank + tuple(s + 2 for s in grid.shape)
if data is None:
# fill full data with zeros by default
data_arr = np.zeros(full_shape, dtype=dtype)
super().__init__(grid, data=data_arr, label=label)
elif isinstance(data, str):
# allocate empty data
if data == "empty":
data_arr = np.empty(full_shape, dtype=dtype)
elif data == "zeros":
data_arr = np.zeros(full_shape, dtype=dtype)
elif data == "ones":
data_arr = np.ones(full_shape, dtype=dtype)
else:
raise ValueError(f"Unknown data '{data}'")
super().__init__(grid, data=data_arr, label=label)
elif isinstance(data, DataFieldBase):
# copy the full data from the supplied field
grid.assert_grid_compatible(data.grid)
data_arr = number_array(data._data_full, dtype=dtype, copy=True)
super().__init__(grid, data=data_arr, label=label)
else:
# initialize empty data and set the valid data
data_arr = number_array(data, dtype=dtype, copy=False)
empty_data = np.empty(full_shape, dtype=data_arr.dtype)
super().__init__(grid, data=empty_data, label=label)
self.data = data_arr
def __repr__(self) -> str:
"""return instance as string"""
class_name = self.__class__.__name__
result = f"{class_name}(grid={self.grid!r}, data={self.data}"
if self.label:
result += f', label="{self.label}"'
return result + ")"
def __str__(self) -> str:
"""return instance as string"""
result = (
f"{self.__class__.__name__}(grid={self.grid}, "
f"data=Array{self.data.shape}"
)
if self.label:
result += f', label="{self.label}"'
return result + ")"
@classmethod
def random_normal(
cls,
grid: GridBase,
mean: float = 0,
std: float = 1,
*,
scaling: str = "none",
label: Optional[str] = None,
dtype=None,
rng: np.random.Generator = None,
) -> DataFieldBase:
"""create field with normal distributed random values
These values are uncorrelated in space. A complex field is returned when either
`mean` or `std` is a complex number. In this case, the real and imaginary parts
of these arguments are used to determine the distribution of the real and
imaginary parts of the resulting field. Consequently,
:code:`ScalarField.random_normal(grid, 0, 1 + 1j)` creates a complex field where
the real and imaginary parts are chosen from a standard normal distribution.
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined
mean (float):
Mean of the Gaussian distribution
std (float):
Standard deviation of the Gaussian distribution.
scaling (str):
Determines how the values are scaled. Possible choices are 'none'
(values are drawn from a normal distribution with given mean and
standard deviation) or 'physical' (the variance of the random number is
scaled by the inverse volume of the grid cell; this is for instance
useful for concentration fields, which vary less in larger volumes).
label (str, optional):
Name of the field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double` if both
`mean` and `std` are real, otherwise it is `complex`.
rng (:class:`~numpy.random.Generator`):
Random number generator (default: :func:`~numpy.random.default_rng()`)
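Example:
A minimal sketch, creating a real and a complex random field (the grid
is illustrative):
.. code-block:: python
grid = pde.UnitGrid([32, 32])
real_field = pde.ScalarField.random_normal(grid, mean=0.5, std=0.1)
# a complex std draws real and imaginary parts independently
complex_field = pde.ScalarField.random_normal(grid, mean=0, std=1 + 1j)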
"""
if rng is None:
rng = np.random.default_rng()
if scaling == "none":
scale = 1
elif scaling == "physical":
scale = 1 / np.sqrt(grid.cell_volumes)
else:
raise ValueError(f"Unknown noise scaling {scaling}")
# determine the shape of the data array
shape = (grid.dim,) * cls.rank + grid.shape
if np.iscomplexobj(mean) or np.iscomplexobj(std):
# create complex random numbers for the field
real_part = np.real(mean) + np.real(std) * scale * rng.normal(size=shape)
imag_part = np.imag(mean) + np.imag(std) * scale * rng.normal(size=shape)
data = real_part + 1j * imag_part
else:
# create real random numbers for the field
data = mean + std * scale * rng.normal(size=shape)
return cls(grid, data=data, label=label, dtype=dtype)
@classmethod
def random_harmonic(
cls,
grid: GridBase,
modes: int = 3,
harmonic=np.cos,
axis_combination=np.multiply,
*,
label: Optional[str] = None,
dtype=None,
rng: np.random.Generator = None,
) -> DataFieldBase:
r"""create a random field build from harmonics
The resulting fields will be highly correlated in space and can thus
serve for testing differential operators.
With the default settings, the resulting field :math:`c_i(\mathbf{x})`
is given by
.. math::
c_i(\mathbf{x}) = \prod_{\alpha=1}^N \sum_{j=1}^M a_{ij\alpha}
\cos\left(\frac{2 \pi j x_\alpha}{L_\alpha}\right) \;,
where :math:`N` is the number of spatial dimensions, each with length
:math:`L_\alpha`, :math:`M` is the number of modes given by `modes`, and
:math:`a_{ij\alpha}` are random amplitudes, chosen from a uniform
distribution over the interval [0, 1].
Note that the product could be replaced by a sum when
`axis_combination = numpy.add` and the :math:`\cos()` could be any other
function given by the parameter `harmonic`.
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined
modes (int):
Number :math:`M` of harmonic modes
harmonic (callable):
Determines which harmonic function is used. Typical values are
:func:`numpy.sin` and :func:`numpy.cos`, which basically relate
to different boundary conditions applied at the grid boundaries.
axis_combination (callable):
Determines how values from different axes are combined. Typical
choices are :func:`numpy.multiply` and :func:`numpy.add`
resulting in products and sums of the values along axes,
respectively.
label (str, optional):
Name of the field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double`.
rng (:class:`~numpy.random.Generator`):
Random number generator (default: :func:`~numpy.random.default_rng()`)
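Example:
A minimal sketch, creating a smooth random field for testing
differential operators (the parameters are illustrative):
.. code-block:: python
grid = pde.UnitGrid([64])
field = pde.ScalarField.random_harmonic(grid, modes=5, harmonic=np.sin)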
"""
if rng is None:
rng = np.random.default_rng()
tensor_shape = (grid.dim,) * cls.rank
data = np.empty(tensor_shape + grid.shape)
# determine random field for each component
for index in np.ndindex(*tensor_shape):
data_axis = []
# random harmonic function along each axis
for i in range(len(grid.axes)):
# choose wave vectors
ampl = rng.random(size=modes) # amplitudes
x = discretize_interval(0, 2 * np.pi, grid.shape[i])[0]
data_axis.append(
sum(a * harmonic(n * x) for n, a in enumerate(ampl, 1))
)
# full dataset is product of values along axes
data[index] = functools.reduce(axis_combination.outer, data_axis)
return cls(grid, data=data, label=label, dtype=dtype)
@classmethod
def random_colored(
cls,
grid: GridBase,
exponent: float = 0,
scale: float = 1,
*,
label: Optional[str] = None,
dtype=None,
rng: np.random.Generator = None,
) -> DataFieldBase:
r"""create a field of random values with colored noise
The spatially correlated values obey
.. math::
\langle c_i(\boldsymbol k) c_j(\boldsymbol k') \rangle =
\Gamma^2 |\boldsymbol k|^\nu \delta_{ij}
\delta(\boldsymbol k - \boldsymbol k')
in spectral space. The special case :math:`\nu = 0` corresponds to white
noise. Note that the components of vector or tensor fields are
uncorrelated.
Args:
grid (:class:`~pde.grids.base.GridBase`):
Grid defining the space on which this field is defined
exponent (float):
Exponent :math:`\nu` of the power spectrum
scale (float):
Scaling factor :math:`\Gamma` determining noise strength
label (str, optional):
Name of the field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double`.
rng (:class:`~numpy.random.Generator`):
Random number generator (default: :func:`~numpy.random.default_rng()`)
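Example:
A minimal sketch, creating spatially correlated noise with a power-law
spectrum (the exponent and scale are illustrative):
.. code-block:: python
grid = pde.UnitGrid([64, 64])
field = pde.ScalarField.random_colored(grid, exponent=-2, scale=0.5)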
"""
# get function making colored noise
from ..tools.spectral import make_colored_noise
make_noise = make_colored_noise(
grid.shape, dx=grid.discretization, exponent=exponent, scale=scale, rng=rng
)
# create random fields for each tensor component
tensor_shape = (grid.dim,) * cls.rank
data = np.empty(tensor_shape + grid.shape)
# determine random field for each component
for index in np.ndindex(*tensor_shape):
data[index] = make_noise()
return cls(grid, data=data, label=label, dtype=dtype)
@classmethod
def get_class_by_rank(cls, rank: int) -> Type[DataFieldBase]:
"""return a :class:`DataFieldBase` subclass describing a field with a given rank
Args:
rank (int): The rank of the tensor field
"""
for field_cls in cls._subclasses.values():
if (
issubclass(field_cls, DataFieldBase)
and not isabstract(field_cls)
and field_cls.rank == rank
):
return field_cls
raise RuntimeError(f"Could not find field class for rank {rank}")
@classmethod
def from_state(
cls, attributes: Dict[str, Any], data: np.ndarray = None
) -> DataFieldBase:
"""create a field from given state.
Args:
attributes (dict):
The attributes that describe the current instance
data (:class:`~numpy.ndarray`, optional):
Data values at the support points of the grid defining the field
"""
if "class" in attributes:
class_name = attributes.pop("class")
assert class_name == cls.__name__
# create the instance from the attributes
return cls(attributes.pop("grid"), data=data, **attributes)
def copy(
self: TDataField,
*,
label: Optional[str] = None,
dtype=None,
) -> TDataField:
"""return a copy of the data, but not of the grid
Args:
label (str, optional):
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it will be determined from
`data` automatically or the dtype of the current field is used.
"""
if label is None:
label = self.label
return self.__class__(
self.grid,
data=np.array(self._data_full, dtype=dtype, copy=True),
label=label,
dtype=dtype,
with_ghost_cells=True,
)
@property
def data_shape(self) -> Tuple[int, ...]:
"""tuple: the shape of the data at each grid point"""
return (self.grid.dim,) * self.rank
@classmethod
def unserialize_attributes(cls, attributes: Dict[str, str]) -> Dict[str, Any]:
"""unserializes the given attributes
Args:
attributes (dict):
The serialized attributes
Returns:
dict: The unserialized attributes
"""
results = {}
for key, value in attributes.items():
if key == "grid":
results[key] = GridBase.from_state(value)
else:
results[key] = json.loads(value)
return results
def _write_to_image(self, filename: str, **kwargs):
r"""write data to image
Args:
filename (str):
The path to the image that will be created
\**kwargs:
Additional keyword arguments that affect the image. For
instance, some fields support a `scalar` argument that
determines how they are converted to a scalar. Non-Cartesian
grids might support a `performance_goal` argument to influence
how an image is created from the raw data. Finally, the
remaining arguments are passed to
:func:`matplotlib.pyplot.imsave` to affect the appearance.
"""
import matplotlib.pyplot as plt
# obtain image data
get_image_args = {}
for key in ["performance_goal", "scalar"]:
if key in kwargs:
get_image_args[key] = kwargs.pop(key)
img = self.get_image_data(**get_image_args)
kwargs.setdefault("cmap", "gray")
plt.imsave(filename, img["data"].T, origin="lower", **kwargs)
def _make_interpolator_scipy(
self, method: str = "linear", fill: Number = None, **kwargs
) -> Callable[[np.ndarray, np.ndarray], NumberOrArray]:
r"""returns a function that can be used to interpolate values.
This uses :class:`scipy.interpolate.RegularGridInterpolator` and thus supports
extra options supplied by keyword arguments. Note that this interpolator does
not respect periodic boundary conditions, yet.
Args:
method (str):
The method used for interpolation. Currently, "linear" and "nearest" are
supported by :class:`~scipy.interpolate.RegularGridInterpolator`.
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
\**kwargs: All keyword arguments are forwarded to
:class:`~scipy.interpolate.RegularGridInterpolator`
Returns:
A function which returns interpolated values when called with
arbitrary positions within the space of the grid.
"""
from scipy import interpolate
coords_src = self.grid.axes_coords
grid_dim = len(self.grid.axes)
if self.rank == 0:
# scalar field => data layout is already usable
data = self.data
revert_shape = False
else:
# spatial dimensions need to come first => move data to last axis
assert self.data.shape[:-grid_dim] == self.data_shape
idx = (slice(None),) + (slice(1, -1),) * self.grid.num_axes
data_flat = self._data_flat[idx]
data_flat = np.moveaxis(data_flat, 0, -1)
new_shape = self.grid.shape + (-1,)
data = data_flat.reshape(new_shape)
assert data.shape[-1] == self.grid.dim ** self.rank
revert_shape = True
# set the fill behavior
if fill is None:
kwargs["bounds_error"] = True
else:
kwargs["bounds_error"] = False
kwargs["fill_value"] = fill
# prepare the interpolator
intp = interpolate.RegularGridInterpolator(
coords_src, data, method=method, **kwargs
)
# determine under which conditions the axes can be squeezed
if grid_dim == 1:
scalar_dim = 0
else:
scalar_dim = 1
# introduce wrapper function to process arrays
def interpolator(point: np.ndarray, **kwargs) -> NumberOrArray:
"""return the interpolated value at the position `point`"""
point = np.atleast_1d(point)
# apply periodic boundary conditions to grid point
point = self.grid.normalize_point(point, reflect=False)
out = intp(point, **kwargs)
if point.ndim == scalar_dim or point.ndim == point.size == 1:
out = out[0]
if revert_shape:
# revert the shuffling of spatial and local axes
out = np.moveaxis(out, point.ndim - 1, 0)
out = out.reshape(self.data_shape + point.shape[:-1])
return out # type: ignore
return interpolator # type: ignore
def _make_interpolator_numba(
self, fill: Number = None, **kwargs
) -> Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]:
"""return a compiled interpolator
Args:
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
Returns:
A function which returns interpolated values when called with
arbitrary positions within the space of the grid.
"""
grid = self.grid
num_axes = self.grid.num_axes
data_shape = self.data_shape
# convert `fill` to dtype of data
if fill is not None:
if self.rank == 0:
fill = self.data.dtype.type(fill)
else:
fill = np.broadcast_to(fill, self.data_shape).astype(self.data.dtype)
# use the full array and assume BCs are set via ghost points
interpolate_single = grid._make_interpolator_compiled(fill=fill, **kwargs)
# extract information about the data field
if kwargs.get("full_data", False):
get_data_array = make_array_constructor(self._data_full)
else:
get_data_array = make_array_constructor(self.data)
dim_error_msg = f"Dimension of point does not match axes count {num_axes}"
@jit
def interpolator(point: np.ndarray, data: np.ndarray = None) -> np.ndarray:
"""return the interpolated value at the position `point`
Args:
point (:class:`~numpy.ndarray`):
The coordinates of the points. Coordinates need to be given along the
last axis, i.e., the shape should be `(..., dim)`.
data (:class:`~numpy.ndarray`, optional):
The discretized field values. If omitted, the data of the current
field is used, which should be the default. However, this option can
be useful to interpolate other fields defined on the same grid
without recreating the interpolator. If a data array is supplied, it
needs to be the valid data, without ghost points.
Returns:
:class:`~numpy.ndarray`: The interpolated values at the points
"""
# check input
point = np.atleast_1d(point)
if point.shape[-1] != num_axes:
raise DimensionError(dim_error_msg)
point_shape = point.shape[:-1]
if data is None:
# reconstruct data field from memory address
data = get_data_array()
# interpolate at every valid point
out = np.empty(data_shape + point_shape, dtype=data.dtype)
for idx in np.ndindex(*point_shape):
out[(...,) + idx] = interpolate_single(data, point[idx])
return out
# store a reference to the data so it is not garbage collected too early
interpolator._data = self.data
return interpolator # type: ignore
@cached_method()
def make_interpolator(
self,
method: str = "linear",
*,
fill: Number = None,
backend: str = "numba",
**kwargs,
) -> Callable[[np.ndarray, np.ndarray], NumberOrArray]:
r"""returns a function that can be used to interpolate values.
Args:
backend (str):
The accepted values `scipy` and `numba` determine the backend that is
used for the interpolation.
method (str):
Determines the method being used for interpolation. Typical values
are "nearest" and "linear", but the supported values depend on the
chosen `backend`.
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
\**kwargs:
Additional keyword arguments are passed to the individual
interpolator methods and can be used to further affect the
behavior.
The scipy implementations use scipy.interpolate.RegularGridInterpolator
and thus do not respect boundary conditions. Additional keyword
arguments are directly forwarded to the constructor of
`RegularGridInterpolator`.
The numba implementation respects boundary conditions, which can be set
using the `bc` keyword argument. Supported values are the same as for
the operators, e.g., the Laplacian. If no boundary conditions are
specified, natural boundary conditions are assumed, which are periodic
conditions for periodic axes and Neumann conditions otherwise.
Returns:
A function which returns interpolated values when called with
arbitrary positions within the space of the grid.
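Example:
A minimal sketch, compiling the interpolator once and evaluating it at
a point given in grid coordinates (the coordinates are illustrative):
.. code-block:: python
grid = pde.UnitGrid([8, 8])
field = pde.ScalarField.random_normal(grid)
interpolator = field.make_interpolator()
value = interpolator(np.array([3.5, 4.2]))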
"""
if backend == "scipy":
return self._make_interpolator_scipy(method=method, fill=fill, **kwargs)
elif backend == "numba":
if method != "linear":
raise NotImplementedError(
"The numba backend currently only supports linear interpolation"
)
return self._make_interpolator_numba(fill=fill, **kwargs)
else:
raise ValueError(f"Unknown backend `{backend}`")
def interpolate(
self,
point: np.ndarray,
*,
backend: str = "numba",
method: str = "linear",
fill: Number = None,
**kwargs,
) -> NumberOrArray:
r"""interpolate the field to points between support points
Args:
point (:class:`~numpy.ndarray`):
The points at which the values should be obtained. This is given in grid
coordinates.
backend (str):
The accepted values "scipy" and "numba" determine the backend that is
used for the interpolation.
method (str):
Determines the method being used for interpolation. Typical values
are "nearest" and "linear", but the supported values depend on the
chosen `backend`.
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
\**kwargs:
Additional keyword arguments are forwarded to the method
:meth:`DataFieldBase.make_interpolator`.
Returns:
:class:`~numpy.ndarray`: the values of the field
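Example:
A minimal sketch, assuming a two-dimensional scalar field (the point is
given in grid coordinates and is illustrative):
.. code-block:: python
grid = pde.UnitGrid([8, 8])
field = pde.ScalarField.random_normal(grid)
value = field.interpolate([3.5, 4.2])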
"""
interpolator = self.make_interpolator(
backend=backend, method=method, fill=fill, **kwargs
)
return interpolator(np.asarray(point)) # type: ignore
def interpolate_to_grid(
self: TDataField,
grid: GridBase,
*,
backend: str = "numba",
method: str = "linear",
fill: Number = None,
label: Optional[str] = None,
) -> TDataField:
"""interpolate the data of this field to another grid.
Args:
grid (:class:`~pde.grids.base.GridBase`):
The grid of the new field onto which the current field is
interpolated.
backend (str):
The accepted values "scipy" and "numba" determine the backend that is
used for the interpolation.
method (str):
Determines the method being used for interpolation. Typical values
are "nearest" and "linear", but the supported values depend on the
chosen `backend`.
fill (Number, optional):
Determines how values out of bounds are handled. If `None`, a
`ValueError` is raised when out-of-bounds points are requested.
Otherwise, the given value is returned.
label (str, optional):
Name of the returned field
Returns:
Field of the same rank as the current one.
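Example:
A minimal sketch, interpolating onto a finer grid that lies within the
source domain (both grids are illustrative):
.. code-block:: python
grid_coarse = pde.UnitGrid([8, 8])
grid_fine = pde.CartesianGrid([[1, 7], [1, 7]], [12, 12])
field = pde.ScalarField.random_normal(grid_coarse)
field_fine = field.interpolate_to_grid(grid_fine)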
"""
if self.grid.dim != grid.dim:
raise DimensionError(
f"Grid dimensions are incompatible "
f"({self.grid.dim:d} != {grid.dim:d})"
)
# determine the points at which data needs to be calculated
if isinstance(grid, CartesianGridBase):
# convert Cartesian coordinates to coordinates in current grid
points = self.grid.point_from_cartesian(grid.cell_coords)
elif (
self.grid.__class__ is grid.__class__
and self.grid.num_axes == grid.num_axes
):
# convert within the same grid class
points = grid.cell_coords
else:
# this type of interpolation is not supported
grid_in = self.grid.__class__.__name__
grid_out = grid.__class__.__name__
raise NotImplementedError(f"Can't interpolate from {grid_in} to {grid_out}")
# interpolate the data to the grid
data = self.interpolate(points, backend=backend, method=method, fill=fill)
return self.__class__(grid, data, label=label)
def add_interpolated(self, point: np.ndarray, amount: ArrayLike) -> None:
"""deprecated alias of method `insert`"""
# this was deprecated on 2021-02-23
warnings.warn(
"`add_interpolated` is deprecated. Use `insert` instead",
DeprecationWarning,
)
self.insert(point, amount)
def insert(self, point: np.ndarray, amount: ArrayLike) -> None:
"""adds an (integrated) value to the field at an interpolated position
Args:
point (:class:`~numpy.ndarray`):
The point inside the grid where the value is added. This is
given in grid coordinates.
amount (Number or :class:`~numpy.ndarray`):
The amount that will be added to the field. The value describes
an integrated quantity (given by the field value times the
discretization volume). This is important for consistency with
different discretizations and in particular grids with
non-uniform discretizations.
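Example:
A minimal sketch, adding an integrated amount at the center of an
illustrative grid (assumes ``import pde`` and ``import numpy as np``):
.. code-block:: python
grid = pde.UnitGrid([8, 8])
field = pde.ScalarField(grid)
field.insert([4.0, 4.0], amount=1)
# the integral of the field now equals the inserted amount
assert np.isclose(field.integral, 1.0)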
"""
point = np.atleast_1d(point)
amount = np.broadcast_to(amount, self.data_shape)
grid = self.grid
grid_dim = len(grid.axes)
if point.size != grid_dim or point.ndim != 1:
raise DimensionError(f"Dimension mismatch for point {point}")
# determine the grid coordinates next to the chosen points
low = np.array(grid.axes_bounds)[:, 0]
c_l, d_l = np.divmod((point - low) / grid.discretization - 0.5, 1.0)
c_l = c_l.astype(np.intc) # support points to the left of the chosen points
w_l = 1 - d_l # weights of the low point
w_h = d_l # weights of the high point
# apply periodic boundary conditions to grid coordinates
c_h = c_l + 1 # support points to the right of the chosen points
for ax in np.flatnonzero(grid.periodic):
c_l[..., ax] %= grid.shape[ax]
c_h[..., ax] %= grid.shape[ax]
# determine the valid points and the total weight in first iteration
total_weight = 0
cells = []
for i in np.ndindex(*((2,) * grid_dim)):
coords = np.choose(i, [c_l, c_h])
if np.all(coords >= 0) and np.all(coords < grid.shape):
weight = np.prod(np.choose(i, [w_l, w_h]))
total_weight += weight
cells.append((tuple(coords), weight))
if total_weight == 0:
raise DomainError("Point lies outside grid")
# alter each point in second iteration
for coords, weight in cells:
chng = weight * amount / (total_weight * grid.cell_volumes[coords])
self.data[(Ellipsis,) + coords] += chng
@fill_in_docstring
def get_boundary_values(
self, axis: int, upper: bool, bc: Optional[BoundariesData] = None
) -> NumberOrArray:
"""get the field values directly on the specified boundary
Args:
axis (int):
The axis perpendicular to the boundary
upper (bool):
Whether the boundary is at the upper side of the axis
bc:
The boundary conditions applied to the field.
{ARG_BOUNDARIES_OPTIONAL}
Returns:
:class:`~numpy.ndarray`: The discretized values on the boundary
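Example:
A minimal sketch, reading the values on the upper boundary along the
first axis (the Dirichlet condition shown is illustrative):
.. code-block:: python
grid = pde.UnitGrid([8, 8])
field = pde.ScalarField.random_normal(grid)
values = field.get_boundary_values(axis=0, upper=True, bc={"value": 0})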
"""
if bc is not None:
self.set_ghost_cells(bc=bc)
l_wall: List[Union[slice, int]] = [slice(1, -1)] * self.grid.num_axes
l_ghost = l_wall.copy()
if upper:
l_wall[axis] = -2
l_ghost[axis] = -1
else:
l_wall[axis] = 1
l_ghost[axis] = 0
i_wall = (...,) + tuple(l_wall)
i_ghost = (...,) + tuple(l_ghost)
return (self._data_full[i_wall] + self._data_full[i_ghost]) / 2 # type: ignore
@fill_in_docstring
def set_ghost_cells(self, bc: BoundariesData, *, args=None) -> None:
"""set the boundary values on virtual points for all boundaries
Args:
bc (str or list or tuple or dict):
The boundary conditions applied to the field.
{ARG_BOUNDARIES}
args:
Additional arguments that might be supported by special boundary
conditions.
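Example:
A minimal sketch, imposing a vanishing-derivative (Neumann) condition
on all boundaries (the condition shown is illustrative):
.. code-block:: python
grid = pde.UnitGrid([8, 8])
field = pde.ScalarField.random_normal(grid)
field.set_ghost_cells({"derivative": 0})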
"""
bcs = self.grid.get_boundary_conditions(bc, rank=self.rank)
bcs.set_ghost_cells(self._data_full, args=args)
@property
@abstractmethod
def integral(self) -> NumberOrArray:
pass
@abstractmethod
def to_scalar(
self, scalar: str = "auto", *, label: Optional[str] = None
) -> "ScalarField":
pass
@property
def average(self) -> NumberOrArray:
"""determine the average of data
This is calculated by integrating each component of the field over space
and dividing by the grid volume
"""
return self.integral / self.grid.volume
@property
def fluctuations(self) -> NumberOrArray:
""":class:`~numpy.ndarray`: fluctuations over the entire space.
The fluctuations are defined as the standard deviation of the data scaled by
the square root of the cell volume. This definition makes the fluctuations
independent of the discretization. It corresponds to the physical scaling
available in :meth:`~DataFieldBase.random_normal`.
Returns:
:class:`~numpy.ndarray`: A tensor with the same rank as the field,
specifying the fluctuations of each component of the tensor field
individually. Consequently, a simple scalar is returned for a
:class:`~pde.fields.scalar.ScalarField`.
"""
scaled_data = self.data * np.sqrt(self.grid.cell_volumes)
axes = tuple(range(self.rank, self.data.ndim))
return np.std(scaled_data, axis=axes) # type: ignore
@property
def magnitude(self) -> float:
"""float: determine the magnitude of the field.
This is calculated by getting a scalar field using the default arguments of the
:func:`to_scalar` method, averaging the result over the whole grid, and taking
the absolute value.
"""
if self.rank == 0:
return abs(self.average) # type: ignore
elif self.rank > 0:
return abs(self.to_scalar().average) # type: ignore
else:
raise AssertionError("Rank must be non-negative")
@fill_in_docstring
def _apply_operator(
self,
operator: str,
bc: Optional[BoundariesData],
out: Optional[DataFieldBase] = None,
*,
label: Optional[str] = None,
args: Dict[str, Any] = None,
**kwargs,
) -> DataFieldBase:
r"""apply an operator and return result as a field
Args:
operator (str):
An identifier determining the operator that is registered on the grid.
bc:
The boundary conditions applied to the field.
{ARG_BOUNDARIES_OPTIONAL}
out (ScalarField, optional):
Optional scalar field to which the result is written.
label (str, optional):
Name of the returned field
**kwargs:
Additional arguments affecting how the operator behaves.
Returns:
Field with new data. This is stored at `out` if given.
"""
# get information about the operator
operator_info = self.grid._get_operator_info(operator)
out_cls = self.get_class_by_rank(operator_info.rank_out)
# prepare the output field
if out is None:
out = out_cls(self.grid, data="empty", label=label, dtype=self.dtype)
elif not isinstance(out, out_cls):
raise RankError(f"`out` must be a {out_cls.__name__}")
else:
self.grid.assert_grid_compatible(out.grid)
if label is not None:
out.label = label
if bc is None:
# apply the operator without imposing boundary conditions
op_raw = self.grid.make_operator_no_bc(operator_info, **kwargs)
op_raw(self._data_full, out.data)
else:
# apply the operator with boundary conditions
op_with_bcs = self.grid.make_operator(operator_info, bc=bc, **kwargs)
# we need to convert the python dict into a numba dict for compilation
op_with_bcs(self.data, out=out.data, args=numba_dict(args))
return out
def smooth(
self: TDataField,
sigma: float = 1,
*,
out: Optional[TDataField] = None,
label: Optional[str] = None,
) -> TDataField:
"""applies Gaussian smoothing with the given standard deviation
This function respects periodic boundary conditions of the underlying
grid, using reflection when no periodicity is specified.
Args:
sigma (float):
Gives the standard deviation of the smoothing in real length units
(default: 1)
out (FieldBase, optional):
Optional field into which the smoothed data is stored. Setting this
to the input field enables in-place smoothing.
label (str, optional):
Name of the returned field
Returns:
Field with smoothed data. This is stored at `out` if given.
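Example:
A minimal sketch, smoothing a field in place with a standard deviation
of two length units (the grid is illustrative):
.. code-block:: python
grid = pde.UnitGrid([64])
field = pde.ScalarField.random_normal(grid)
field.smooth(sigma=2, out=field)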
"""
from scipy import ndimage
# allocate memory for storing output
if out is None:
out = self.__class__(self.grid, label=self.label)
else:
self.assert_field_compatible(out)
# apply Gaussian smoothing for each axis
data_in = self.data # use the field data as input
data_out = out.data # write to the output
for axis in range(-len(self.grid.axes), 0):
sigma_dx = sigma / self.grid.discretization[axis]
mode = "wrap" if self.grid.periodic[axis] else "reflect"
ndimage.gaussian_filter1d(
data_in, sigma=sigma_dx, axis=axis, output=data_out, mode=mode
)
data_in = data_out # use this smoothed data as input for next axis
# return the data in the correct field class
if label:
out.label = label
return out
def get_line_data(
self, scalar: str = "auto", extract: str = "auto"
) -> Dict[str, Any]:
"""return data for a line plot of the field
Args:
scalar (str or int):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
extract (str):
The method used for extracting the line data. See the docstring
of the grid method `get_line_data` to find supported values.
Returns:
dict: Information useful for performing a line plot of the field
"""
# turn field into scalar field
scalar_data = self.to_scalar(scalar).data
# extract the line data
data = self.grid.get_line_data(scalar_data, extract=extract)
if "label_y" in data and data["label_y"]:
if self.label:
data["label_y"] = f"{self.label} ({data['label_y']})"
else:
data["label_y"] = self.label
return data
def get_image_data(
self, scalar: str = "auto", transpose: bool = False, **kwargs
) -> Dict[str, Any]:
r"""return data for plotting an image of the field
Args:
scalar (str or int):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
transpose (bool):
Determines whether the transpose of the data should be plotted
\**kwargs:
Additional parameters are forwarded to `grid.get_image_data`
Returns:
dict: Information useful for plotting an image of the field
"""
# turn field into scalar field
scalar_data = self.to_scalar(scalar).data
# remove imaginary parts
if self.is_complex:
self._logger.warning("Only the absolute value of complex data is shown")
scalar_data = abs(scalar_data)
# extract the image data
data = self.grid.get_image_data(scalar_data, **kwargs)
data["title"] = self.label
if transpose:
# adjust image data such that the transpose is plotted
data["data"] = data["data"].T
data["label_x"], data["label_y"] = data["label_y"], data["label_x"]
return data
def get_vector_data(self, **kwargs) -> Dict[str, Any]:
r"""return data for a vector plot of the field
Args:
\**kwargs: Additional parameters are forwarded to
`grid.get_image_data`
Returns:
dict: Information useful for plotting a vector field
"""
raise NotImplementedError()
def _plot_line(
self,
ax,
scalar: str = "auto",
extract: str = "auto",
ylabel: Optional[str] = None,
**kwargs,
) -> PlotReference:
r"""visualize a field using a 1d line plot
Args:
ax (:class:`matplotlib.axes.Axes`):
Figure axes to be used for plotting.
scalar (str or int):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
extract (str):
The method used for extracting the line data.
ylabel (str):
Label of the y-axis. If omitted, the label is chosen
automatically from the data field.
\**kwargs:
Additional keyword arguments are passed to
:func:`matplotlib.pyplot.plot`
Returns:
:class:`PlotReference`: Instance that contains information to update
the plot with new data later.
"""
# obtain data for the plot
line_data = self.get_line_data(scalar=scalar, extract=extract)
# warn if there is an imaginary part
if np.any(np.iscomplex(line_data["data_y"])):
self._logger.warning("Only the real part of the complex data is shown")
# do the plot
(line2d,) = ax.plot(line_data["data_x"], line_data["data_y"].real, **kwargs)
# set some default properties
ax.set_xlabel(line_data["label_x"])
if ylabel is None:
ylabel = line_data.get("label_y", self.label)
if ylabel:
ax.set_ylabel(ylabel)
return PlotReference(ax, line2d, {"scalar": scalar, "extract": extract})
def _update_line_plot(self, reference: PlotReference) -> None:
"""update a line plot with the current field values
Args:
reference (:class:`PlotReference`):
The reference to the plot that is updated
"""
import matplotlib as mpl
# obtain data for the plot
scalar = reference.parameters.get("scalar", "auto")
extract = reference.parameters.get("extract", "auto")
line_data = self.get_line_data(scalar=scalar, extract=extract)
line2d = reference.element
if isinstance(line2d, mpl.lines.Line2D):
# update old plot
line2d.set_xdata(line_data["data_x"])
line2d.set_ydata(line_data["data_y"].real)
else:
raise ValueError(f"Unsupported plot reference {reference}")
def _plot_image(
self,
ax,
colorbar: bool = True,
scalar: str = "auto",
transpose: bool = False,
**kwargs,
) -> PlotReference:
r"""visualize a field using a 2d density plot
Args:
ax (:class:`matplotlib.axes.Axes`):
Figure axes to be used for plotting.
colorbar (bool):
Determines whether a colorbar is shown
scalar (str or int):
The method for extracting scalars as described in
:meth:`DataFieldBase.to_scalar`.
transpose (bool):
Determines whether the transpose of the data should be plotted
\**kwargs:
Additional keyword arguments that affect the image. For instance, some
fields support a `scalar` argument that determines how they are
converted to a scalar. Non-Cartesian grids might support a
`performance_goal` argument to influence how an image is created from
the raw data. Finally, the remaining arguments are passed to
:func:`matplotlib.pyplot.imshow` to affect the appearance.
Returns:
:class:`PlotReference`: Instance that contains information to update
the plot with new data later.
"""
# obtain image data with appropriate parameters
data_kws = {}
for arg in ["performance_goal", "scalar", "transpose"]:
if arg in kwargs:
data_kws[arg] = kwargs.pop(arg)
data = self.get_image_data(scalar, transpose, **data_kws)
if ax is None:
import matplotlib.pyplot as plt
# create new figure
ax = plt.subplots()[1]
# plot the image
kwargs.setdefault("origin", "lower")
kwargs.setdefault("interpolation", "none")
axes_image = ax.imshow(data["data"].T, extent=data["extent"], **kwargs)
# set some default properties
ax.set_xlabel(data["label_x"])
ax.set_ylabel(data["label_y"])
ax.set_title(data.get("title", self.label))
if colorbar:
from ..tools.plotting import add_scaled_colorbar
add_scaled_colorbar(axes_image, ax=ax)
# store parameters of the plot that are necessary for updating
parameters = {"scalar": scalar, "transpose": transpose}
if "vmin" in kwargs:
parameters["vmin"] = kwargs["vmin"]
if "vmax" in kwargs:
parameters["vmax"] = kwargs["vmax"]
return PlotReference(ax, axes_image, parameters)
def _update_image_plot(self, reference: PlotReference) -> None:
"""update an image plot with the current field values
Args:
reference (:class:`PlotReference`):
The reference to the plot that is updated
"""
# obtain image data
p = reference.parameters
data = self.get_image_data(
scalar=p.get("scalar", "auto"), transpose=p.get("transpose", False)
)
# update the axes image
reference.element.set_data(data["data"].T)
# adjust the colorbar limits
vmin = p["vmin"] if "vmin" in p else data["data"].min()
vmax = p["vmax"] if "vmax" in p else data["data"].max()
reference.element.set_clim(vmin, vmax)
def _plot_vector(
self,
ax,
method: str = "quiver",
transpose: bool = False,
max_points: int = 16,
**kwargs,
) -> PlotReference:
r"""visualize a field using a 2d vector plot
Args:
ax (:class:`matplotlib.axes.Axes`):
Figure axes to be used for plotting.
method (str):
Plot type that is used. This can be either `quiver` or
`streamplot`.
transpose (bool):
Determines whether the transpose of the data should be plotted.
max_points (int):
The maximal number of points that is used along each axis. This
argument is only used for quiver plots.
\**kwargs:
Additional keyword arguments are passed to
:func:`matplotlib.pyplot.quiver` or
:func:`matplotlib.pyplot.streamplot`.
Returns:
:class:`PlotReference`: Instance that contains information to update
the plot with new data later.
"""
# store the parameters of this plot for later updating
parameters = {
"method": method,
"transpose": transpose,
"kwargs": kwargs,
}
if method == "quiver":
# plot vector field using a quiver plot
data = self.get_vector_data(transpose=transpose, max_points=max_points)
parameters["max_points"] = max_points # only save for quiver plot
element = ax.quiver(
data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs
)
elif method == "streamplot":
# plot vector field using a streamplot
data = self.get_vector_data(transpose=transpose)
element = ax.streamplot(
data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs
)
else:
raise ValueError(f"Vector plot `{method}` is not supported.")
# set some default properties of the plot
ax.set_aspect("equal")
ax.set_xlabel(data["label_x"])
ax.set_ylabel(data["label_y"])
ax.set_title(data.get("title", self.label))
return PlotReference(ax, element, parameters)
def _update_vector_plot(self, reference: PlotReference) -> None:
"""update a vector plot with the current field values
Args:
reference (:class:`PlotReference`):
The reference to the plot that is updated
"""
# extract general parameters
method = reference.parameters.get("method", "quiver")
transpose = reference.parameters.get("transpose", False)
if method == "quiver":
# update the data of a quiver plot
max_points = reference.parameters.get("max_points")
data = self.get_vector_data(transpose=transpose, max_points=max_points)
reference.element.set_UVC(data["data_x"], data["data_y"])
elif method == "streamplot":
# update a streamplot by redrawing it completely
ax = reference.ax
kwargs = reference.parameters.get("kwargs", {})
data = self.get_vector_data(transpose=transpose)
# remove old streamplot
ax.cla()
# update with new streamplot
reference.element = ax.streamplot(
data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs
)
else:
raise ValueError(f"Vector plot `{method}` is not supported.")
def _update_plot(self, reference: PlotReference) -> None:
"""update a plot with the current field values
Args:
reference (:class:`PlotReference`):
The reference to the plot to be updated
"""
import matplotlib as mpl
# update the plot based on the given reference
el = reference.element
if isinstance(el, mpl.lines.Line2D):
self._update_line_plot(reference)
elif isinstance(el, mpl.image.AxesImage):
self._update_image_plot(reference)
elif isinstance(el, (mpl.quiver.Quiver, mpl.streamplot.StreamplotSet)):
self._update_vector_plot(reference)
else:
raise ValueError(f"Unknown plot element {el.__class__.__name__}")
@plot_on_axes(update_method="_update_plot")
def plot(self, kind: str = "auto", **kwargs) -> PlotReference:
r"""visualize the field
Args:
kind (str):
Determines the visualization. Supported values are `image`,
`line`, or `vector`. Alternatively, `auto` determines the best
visualization based on the field itself.
{PLOT_ARGS}
\**kwargs:
All additional keyword arguments are forwarded to the actual
plotting function.
Returns:
:class:`PlotReference`: Instance that contains information to update
the plot with new data later.
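Example:
A minimal sketch, plotting a two-dimensional field as an image (the
grid is illustrative):
.. code-block:: python
grid = pde.UnitGrid([32, 32])
field = pde.ScalarField.random_normal(grid)
reference = field.plot(kind="image", colorbar=True)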
"""
# determine the correct kind of plotting
if kind == "auto":
# determine best plot for this field
if (
isinstance(self, DataFieldBase)
and self.rank == 1
and self.grid.dim == 2
):
kind = "vector"
elif len(self.grid.shape) == 1:
kind = "line"
else:
kind = "image"
elif kind == "quiver":
kind = "vector"
kwargs["method"] = "quiver"
elif kind == "streamplot":
kind = "vector"
kwargs["method"] = "streamplot"
# do the actual plotting
if kind == "image":
reference = self._plot_image(**kwargs)
elif kind == "line":
reference = self._plot_line(**kwargs)
elif kind == "vector":
reference = self._plot_vector(**kwargs)
else:
raise ValueError(
f"Unsupported plot `{kind}`. Possible choices are `image`, `line`, "
"`vector`, or `auto`."
)
return reference
def _get_napari_layer_data(
self, scalar: str = "auto", args: Dict[str, Any] = None
) -> Dict[str, Any]:
"""returns data for plotting on a single napari layer
Args:
scalar (str):
Indicates how the scalar field is generated; see `to_scalar`
args (dict):
Additional arguments returned in the result, which affect how the layer
is shown.
Returns:
dict: all the information necessary to plot this field
"""
result = {} if args is None else args.copy()
result.setdefault("scale", self.grid.discretization)
result.setdefault("rgb", False)
result["type"] = "image"
result["data"] = self.to_scalar(scalar).data
return result
def _get_napari_data(self, **kwargs) -> Dict[str, Dict[str, Any]]:
r"""returns data for plotting this field
Args:
\**kwargs: all arguments are forwarded to `_get_napari_layer_data`
Returns:
dict: all the information necessary to plot this field
"""
name = "Field" if self.label is None else self.label
return {name: self._get_napari_layer_data(**kwargs)}
def _get_field_class_by_rank(rank: int) -> Type[DataFieldBase]:
"""return a field class associated with a certain rank
Args:
rank (int): The rank of the tensor field
"""
# deprecated on 2021-09-17
warnings.warn("Use DataFieldBase.get_class_by_rank instead.", DeprecationWarning)
return DataFieldBase.get_class_by_rank(rank)