diff --git a/examples/fmm-error.py b/examples/fmm-error.py index 9b4d0fca7..5b925bd58 100644 --- a/examples/fmm-error.py +++ b/examples/fmm-error.py @@ -58,7 +58,7 @@ def main(): "unaccel_qbx": unaccel_qbx, "qbx": unaccel_qbx.copy(fmm_order=10), "targets": PointsTarget(actx.freeze(actx.from_numpy(fplot.points))) - }) + }, auto_where=("qbx", "targets")) density_discr = places.get_discretization("unaccel_qbx") nodes = thaw(density_discr.nodes(), actx) diff --git a/pytential/symbolic/execution.py b/pytential/symbolic/execution.py index 9feb1578a..b4d56ad59 100644 --- a/pytential/symbolic/execution.py +++ b/pytential/symbolic/execution.py @@ -634,39 +634,62 @@ class _GeometryCollectionConnectionCacheKey: class GeometryCollection: """A mapping from symbolic identifiers ("place IDs", typically strings) to 'geometries', where a geometry can be a - :class:`pytential.source.PotentialSource` - or a :class:`pytential.target.TargetBase`. + :class:`~pytential.source.PotentialSource`, a + :class:`~pytential.target.TargetBase` or a + :class:`~meshmode.discretization.Discretization`. + This class is meant to hold a specific combination of sources and targets serve to host caches of information derived from them, e.g. FMM trees of subsets of them, as well as related common subexpressions such as metric terms. + Refinement of :class:`pytential.qbx.QBXLayerPotentialSource` entries is + performed on demand, i.e. on calls to :meth:`get_discretization` with + a specific *discr_stage*. To perform refinement explicitly, call + :func:`pytential.qbx.refinement.refine_geometry_collection`, + which allows more customization of the refinement process through + parameters. + + .. automethod:: __init__ + + .. attribute:: auto_source + + Default :class:`~pytential.symbolic.primitives.DOFDescriptor` for the + source geometry. + + .. attribute:: auto_target + + Default :class:`~pytential.symbolic.primitives.DOFDescriptor` for the + target geometry. + .. automethod:: get_geometry - .. automethod:: get_connection .. automethod:: get_discretization + .. automethod:: get_connection .. automethod:: copy .. automethod:: merge - Refinement of :class:`pytential.qbx.QBXLayerPotentialSource` entries is - performed on demand, or it may be performed by explcitly calling - :func:`pytential.qbx.refinement.refine_geometry_collection`, - which allows more customization of the refinement process through - parameters. """ def __init__(self, places, auto_where=None): - """ + r""" :arg places: a scalar, tuple of or mapping of symbolic names to geometry objects. Supported objects are :class:`~pytential.source.PotentialSource`, :class:`~pytential.target.TargetBase` and :class:`~meshmode.discretization.Discretization`. If this is a mapping, the keys that are strings must be valid Python identifiers. - :arg auto_where: location identifier for each geometry object, used - to denote specific discretizations, e.g. in the case where - *places* is a :class:`~pytential.source.LayerPotentialSourceBase`. - By default, we assume + The tuple should contain only two entries, denoting the source and + target geometries for layer potential evaluation, identified by + *auto_where*. + + :arg auto_where: a single or a tuple of two + :class:`~pytential.symbolic.primitives.DOFDescriptor`\ s, or values + that can be converted to one using + :func:`~pytential.symbolic.primitives.as_dofdesc`. The two + descriptors are used to define the default source and target + geometries for layer potential evaluations. + By default, they are set to :class:`~pytential.symbolic.primitives.DEFAULT_SOURCE` and :class:`~pytential.symbolic.primitives.DEFAULT_TARGET` for sources and targets, respectively. @@ -705,6 +728,15 @@ def __init__(self, places, auto_where=None): # {{{ validate + # check auto_where + if auto_source.geometry not in self.places: + raise ValueError("'auto_where' source geometry is not in the " + f"collection: '{auto_source.geometry}'") + + if auto_target.geometry not in self.places: + raise ValueError("'auto_where' target geometry is not in the " + f"collection: '{auto_target.geometry}'") + # check allowed identifiers for name in self.places: if not isinstance(name, str): @@ -715,8 +747,9 @@ def __init__(self, places, auto_where=None): # check allowed types for p in self.places.values(): if not isinstance(p, (PotentialSource, TargetBase, Discretization)): - raise TypeError("Values in 'places' must be discretization, targets " - "or layer potential sources.") + raise TypeError( + "Values in 'places' must be discretization, targets " + f"or layer potential sources, got '{type(p).__name__}'") # check ambient_dim from pytools import is_single_valued @@ -746,8 +779,9 @@ def _get_discr_from_cache(self, geometry, discr_stage): key = (geometry, discr_stage) if key not in cache: - raise KeyError("cached discretization does not exist on '{}'" - "for stage '{}'".format(geometry, discr_stage)) + raise KeyError( + "cached discretization does not exist on '{geometry}'" + "for stage '{discr_stage}'") return cache[key] @@ -756,7 +790,8 @@ def _add_discr_to_cache(self, discr, geometry, discr_stage): key = (geometry, discr_stage) if key in cache: - raise RuntimeError("trying to overwrite the cache") + raise RuntimeError("trying to overwrite the discretization cache of " + f"'{geometry}' for stage '{discr_stage}'") cache[key] = discr @@ -765,8 +800,8 @@ def _get_conn_from_cache(self, geometry, from_stage, to_stage): key = (geometry, from_stage, to_stage) if key not in cache: - raise KeyError("cached connection does not exist on '{}' " - "from '{}' to '{}'".format(geometry, from_stage, to_stage)) + raise KeyError("cached connection does not exist on " + f"'{geometry}' from stage '{from_stage}' to '{to_stage}'") return cache[key] @@ -775,7 +810,8 @@ def _add_conn_to_cache(self, conn, geometry, from_stage, to_stage): key = (geometry, from_stage, to_stage) if key in cache: - raise RuntimeError("trying to overwrite the cache") + raise RuntimeError("trying to overwrite the connection cache of " + f"'{geometry}' from stage '{from_stage}' to '{to_stage}'") cache[key] = conn @@ -801,19 +837,38 @@ def _get_qbx_discretization(self, geometry, discr_stage): # }}} def get_connection(self, from_dd, to_dd): + """Construct a connection from *from_dd* to *to_dd* geometries. + + :param from_dd: a :class:`~pytential.symbolic.primitives.DOFDescriptor` + or a value that can be converted to one using + :func:`~pytential.symbolic.primitives.as_dofdesc`. + :param to_dd: as *from_dd*. + + :returns: an object compatible with the + :class:`~meshmode.discretization.connection.DiscretizationConnection` + interface. + """ + from pytential.symbolic.dof_connection import connection_from_dds return connection_from_dds(self, from_dd, to_dd) def get_discretization(self, geometry, discr_stage=None): - """ - :arg dofdesc: a :class:`~pytential.symbolic.primitives.DOFDescriptor` - specifying the desired discretization. - - :return: a geometry object in the collection corresponding to the - key *dofdesc*. If it is a - :class:`~pytential.source.LayerPotentialSourceBase`, we look for - the corresponding :class:`~meshmode.discretization.Discretization` - in its attributes instead. + """Get the geometry or discretization in the collection. + + If a specific QBX stage discretization is requested, refinement is + performed on demand and cached for subsequent calls. + + :param geometry: the identifier of the geometry in the collection. + :param discr_stage: if the geometry is a + :class:`~pytential.source.LayerPotentialSourceBase`, this denotes + the QBX stage of the returned discretization. Can be one of + :class:`~pytential.symbolic.primitives.QBX_SOURCE_STAGE1` (default), + :class:`~pytential.symbolic.primitives.QBX_SOURCE_STAGE2` or + :class:`~pytential.symbolic.primitives.QBX_SOURCE_QUAD_STAGE2`. + + :returns: a geometry object in the collection or a + :class:`~meshmode.discretization.Discretization` corresponding to + *discr_stage*. """ if discr_stage is None: discr_stage = sym.QBX_SOURCE_STAGE1 @@ -830,13 +885,18 @@ def get_discretization(self, geometry, discr_stage=None): return discr def get_geometry(self, geometry): + """ + :param geometry: the identifier of the geometry in the collection. + """ + try: return self.places[geometry] except KeyError: - raise KeyError("geometry not in the collection: '{}'".format( - geometry)) + raise KeyError(f"geometry not in the collection: '{geometry}'") def copy(self, places=None, auto_where=None): + """Get a shallow copy of the geometry collection.""" + places = self.places if places is None else places return type(self)( places=places.copy(), @@ -845,7 +905,7 @@ def copy(self, places=None, auto_where=None): def merge(self, places): """Merges two geometry collections and returns the new collection. - :arg places: A :class:`dict` or :class:`GeometryCollection` to + :param places: a :class:`dict` or :class:`GeometryCollection` to merge with the current collection. If it is empty, a copy of the current collection is returned. """ @@ -859,10 +919,10 @@ def merge(self, places): return self.copy(places=new_places) def __repr__(self): - return "{}({})".format(type(self).__name__, repr(self.places)) + return "{type(self).__name__}({repr(self.places)})" def __str__(self): - return "{}({})".format(type(self).__name__, str(self.places)) + return "{type(self).__name__}({repr(self.places)})" # }}} diff --git a/test/test_layer_pot.py b/test/test_layer_pot.py index f9ce9c83b..bd91b1b08 100644 --- a/test/test_layer_pot.py +++ b/test/test_layer_pot.py @@ -187,7 +187,8 @@ def test_off_surface_eval_vs_direct(actx_factory, do_plot=False): places = GeometryCollection({ "direct_qbx": direct_qbx, "fmm_qbx": fmm_qbx, - "target": ptarget}) + "target": ptarget, + }, auto_where=("fmm_qbx", "target")) direct_density_discr = places.get_discretization("direct_qbx") fmm_density_discr = places.get_discretization("fmm_qbx") @@ -273,7 +274,8 @@ def test_single_plus_double_with_single_fmm(actx_factory, do_plot=False): places = GeometryCollection({ "direct_qbx": direct_qbx, "fmm_qbx": fmm_qbx, - "target": ptarget}) + "target": ptarget, + }, auto_where=("fmm_qbx", "target")) direct_density_discr = places.get_discretization("direct_qbx") fmm_density_discr = places.get_discretization("fmm_qbx") @@ -413,7 +415,8 @@ def test_unregularized_off_surface_fmm_vs_direct(actx_factory): places = GeometryCollection({ "unregularized_direct": direct, "unregularized_fmm": fmm, - "targets": ptarget}) + "targets": ptarget, + }, auto_where=("unregularized_fmm", "targets")) # }}} diff --git a/test/test_tools.py b/test/test_tools.py index 5657ed4ec..1fa024388 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -145,7 +145,7 @@ def test_geometry_collection_caching(actx_factory): # construct a geometry collection from pytential import GeometryCollection - places = GeometryCollection(dict(zip(sources, lpots))) + places = GeometryCollection(dict(zip(sources, lpots)), auto_where=sources[0]) print(places.places) # check on-demand refinement