diff --git a/doc/operators.rst b/doc/operators.rst index 550a78023..881afd945 100644 --- a/doc/operators.rst +++ b/doc/operators.rst @@ -3,6 +3,7 @@ Discontinuous Galerkin operators .. automodule:: grudge.op .. automodule:: grudge.trace_pair +.. automodule:: grudge.flux_differencing Transfering data between discretizations diff --git a/grudge/array_context.py b/grudge/array_context.py index 00e6fb0d6..62f117871 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -38,7 +38,9 @@ from pytools.tag import Tag from meshmode.array_context import ( PyOpenCLArrayContext as _PyOpenCLArrayContextBase, - PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase) + # TODO: Get SingleGridWorkBalancingPytatoArrayContext merged into main + SingleGridWorkBalancingPytatoArrayContext as _PytatoPyOpenCLArrayContextBase, + ) from pyrsistent import pmap from warnings import warn diff --git a/grudge/flux_differencing.py b/grudge/flux_differencing.py new file mode 100644 index 000000000..8dc6421d2 --- /dev/null +++ b/grudge/flux_differencing.py @@ -0,0 +1,264 @@ +"""Grudge module for flux-differencing in entropy-stable DG methods +Flux-differencing +----------------- +.. autofunction:: volume_flux_differencing +""" + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from arraycontext import ( + ArrayContext, + map_array_container, + freeze +) +from arraycontext.context import ArrayOrContainerT + +from functools import partial + +from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.dof_array import DOFArray + +from grudge.discretization import DiscretizationCollection +from grudge.dof_desc import DOFDesc + +from pytools import memoize_in, keyed_memoize_in + +import numpy as np + + +def _reference_skew_symmetric_hybridized_sbp_operators( + actx: ArrayContext, + base_element_group, + vol_quad_element_group, + face_quad_element_group, dtype): + @keyed_memoize_in( + actx, _reference_skew_symmetric_hybridized_sbp_operators, + lambda base_grp, quad_vol_grp, face_quad_grp: ( + base_grp.discretization_key(), + quad_vol_grp.discretization_key(), + face_quad_grp.discretization_key())) + def get_reference_skew_symetric_hybridized_diff_mats( + base_grp, quad_vol_grp, face_quad_grp): + from meshmode.discretization.poly_element import diff_matrices + from modepy import faces_for_shape, face_normal + from grudge.interpolation import ( + volume_quadrature_interpolation_matrix, + surface_quadrature_interpolation_matrix + ) + from grudge.op import reference_inverse_mass_matrix + + # {{{ Volume operators + + weights = quad_vol_grp.quadrature_rule().weights + vdm_q = actx.to_numpy( + volume_quadrature_interpolation_matrix(actx, base_grp, quad_vol_grp)) + inv_mass_mat = actx.to_numpy( + reference_inverse_mass_matrix(actx, base_grp)) + p_mat = inv_mass_mat @ (vdm_q.T * weights) + + # }}} + + # {{{ Surface operators + + faces = faces_for_shape(base_grp.shape) + nfaces = len(faces) + # NOTE: assumes same quadrature rule on all faces + face_weights = np.tile(face_quad_grp.quadrature_rule().weights, nfaces) + face_normals = [face_normal(face) for face in faces] + nnods_per_face = face_quad_grp.nunit_dofs + e = np.ones(shape=(nnods_per_face,)) + nrstj = [ + # nsrtJ = nhat * Jhatf, where nhat is the reference normal + # and Jhatf is the Jacobian det. of the transformation from + # the face of the reference element to the reference face. + np.concatenate([np.sign(nhat[idx])*e for nhat in face_normals]) + for idx in range(base_grp.dim) + ] + b_mats = [np.diag(face_weights*nrstj[d]) for d in range(base_grp.dim)] + vf_mat = actx.to_numpy( + surface_quadrature_interpolation_matrix( + actx, + base_element_group=base_grp, + face_quad_element_group=face_quad_grp)) + zero_mat = np.zeros((nfaces*nnods_per_face, nfaces*nnods_per_face), + dtype=dtype) + + # }}} + + # {{{ Hybridized (volume + surface) operators + + q_mats = [p_mat.T @ (weights * vdm_q.T @ vdm_q) @ diff_mat @ p_mat + for diff_mat in diff_matrices(base_grp)] + e_mat = vf_mat @ p_mat + q_skew_hybridized = np.asarray( + [ + np.block( + [[q_mats[d] - q_mats[d].T, e_mat.T @ b_mats[d]], + [-b_mats[d] @ e_mat, zero_mat]] + ) for d in range(base_grp.dim) + ], + order="C" + ) + + # }}} + + return actx.freeze(actx.from_numpy(q_skew_hybridized)) + + return get_reference_skew_symetric_hybridized_diff_mats( + base_element_group, + vol_quad_element_group, + face_quad_element_group + ) + + +def _single_axis_hybridized_derivative_kernel( + dcoll, dd_quad, dd_face_quad, xyz_axis, flux_matrix): + if not dcoll._has_affine_groups(): + raise NotImplementedError("Not implemented for non-affine elements yet.") + + if not isinstance(flux_matrix, DOFArray): + return map_array_container( + partial(_single_axis_hybridized_derivative_kernel, + dcoll, dd_quad, dd_face_quad, xyz_axis), + flux_matrix + ) + + from grudge.geometry import \ + area_element, inverse_surface_metric_derivative + from grudge.interpolation import ( + volume_and_surface_interpolation_matrix, + volume_and_surface_quadrature_interpolation + ) + + actx = flux_matrix.array_context + + # FIXME: This is kinda meh + def inverse_jac_matrix(): + @memoize_in( + dcoll, + (_single_axis_hybridized_derivative_kernel, dd_quad, dd_face_quad)) + def _inv_surf_metric_deriv(): + return freeze( + actx.np.stack( + [ + actx.np.stack( + [ + volume_and_surface_quadrature_interpolation( + dcoll, dd_quad, dd_face_quad, + area_element(actx, dcoll) + * inverse_surface_metric_derivative( + actx, dcoll, + rst_ax, xyz_axis + ) + ) for rst_ax in range(dcoll.dim) + ] + ) for xyz_axis in range(dcoll.ambient_dim) + ] + ), + actx + ) + return _inv_surf_metric_deriv() + + return DOFArray( + actx, + data=tuple( + # r for rst axis + actx.einsum("ik,rej,rij,eij->ek", + volume_and_surface_interpolation_matrix( + actx, + base_element_group=bgrp, + vol_quad_element_group=qvgrp, + face_quad_element_group=qafgrp + ), + ijm_i[xyz_axis], + _reference_skew_symmetric_hybridized_sbp_operators( + actx, + bgrp, + qvgrp, + qafgrp, + fmat_i.dtype + ), + fmat_i, + arg_names=("Vh_mat_t", "inv_jac_t", "Q_mat", "F_mat"), + tagged=(FirstAxisIsElementsTag(),)) + + for bgrp, qvgrp, qafgrp, fmat_i, ijm_i in zip( + dcoll.discr_from_dd("vol").groups, + dcoll.discr_from_dd(dd_quad).groups, + dcoll.discr_from_dd(dd_face_quad).groups, + flux_matrix, + inverse_jac_matrix() + ) + ) + ) + + +def volume_flux_differencing( + dcoll: DiscretizationCollection, + dd_quad: DOFDesc, + dd_face_quad: DOFDesc, + flux_matrices: ArrayOrContainerT) -> ArrayOrContainerT: + r"""Computes the volume contribution of the DG divergence operator using + flux-differencing: + .. math:: + \mathrm{VOL} = \sum_{i=1}^{d} + \begin{bmatrix} + \mathbf{V}_q \\ \mathbf{V}_f + \end{bmatrix}^T + \left( + \left( \mathbf{Q}_{i} - \mathbf{Q}^T_{i} \right) + \circ \mathbf{F}_{i} + \right)\mathbf{1} + where :math:`\circ` denotes the + `Hadamard product `__, + :math:`\mathbf{F}_{i}` are matrices whose entries are computed + as the evaluation of an entropy-conserving two-point flux function + (e.g. :func:`grudge.models.euler.divergence_flux_chandrashekar`) + and :math:`\mathbf{Q}_{i} - \mathbf{Q}^T_{i}` are the skew-symmetric + hybridized differentiation operators defined in (15) of + `this paper `__. + :arg flux_matrices: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them containing + evaluations of two-point flux. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + """ +# from grudge.op import _div_helper +# +# return _div_helper( +# dcoll, +# lambda _, i, flux_mat_i: _single_axis_hybridized_derivative_kernel( +# dcoll, dd_quad, dd_face_quad, i, flux_mat_i), +# flux_matrices +# ) + + from grudge.tools import rec_map_subarrays + return rec_map_subarrays( + f=lambda vec: sum( + local_d_dx(dcol, i, vec_i) + for i, vec_i, in enumerate(vec)), + in_shape=(dcoll.ambient_dim,) + out_shape=(), + is_scalar=lambda v: isinstance(v, DOFArray), + ary=vecs) diff --git a/grudge/interpolation.py b/grudge/interpolation.py index 61bdf1a13..05c0fc72e 100644 --- a/grudge/interpolation.py +++ b/grudge/interpolation.py @@ -30,11 +30,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import numpy as np + +from arraycontext import ( + ArrayContext, + map_array_container +) +from arraycontext.context import ArrayOrContainerT + +from functools import partial + +from meshmode.transform_metadata import FirstAxisIsElementsTag from grudge.discretization import DiscretizationCollection +from grudge.dof_desc import DOFDesc +from meshmode.dof_array import DOFArray + +from pytools import keyed_memoize_in # FIXME: Should revamp interp and make clear distinctions # between projection and interpolations. # Related issue: https://github.com/inducer/grudge/issues/38 @@ -46,3 +61,120 @@ def interp(dcoll: DiscretizationCollection, src, tgt, vec): from grudge.projection import project return project(dcoll, src, tgt, vec) +def volume_quadrature_interpolation_matrix( + actx: ArrayContext, base_element_group, vol_quad_element_group): + @keyed_memoize_in( + actx, volume_quadrature_interpolation_matrix, + lambda base_grp, vol_quad_grp: (base_grp.discretization_key(), + vol_quad_grp.discretization_key())) + def get_volume_vand(base_grp, vol_quad_grp): + from modepy import vandermonde + + basis = base_grp.basis_obj() + vdm_inv = np.linalg.inv(vandermonde(basis.functions, + base_grp.unit_nodes)) + vdm_q = vandermonde(basis.functions, vol_quad_grp.unit_nodes) @ vdm_inv + return actx.freeze(actx.from_numpy(vdm_q)) + + return get_volume_vand(base_element_group, vol_quad_element_group) + + +def surface_quadrature_interpolation_matrix( + actx: ArrayContext, base_element_group, face_quad_element_group): + @keyed_memoize_in( + actx, surface_quadrature_interpolation_matrix, + lambda base_grp, face_quad_grp: (base_grp.discretization_key(), + face_quad_grp.discretization_key())) + def get_surface_vand(base_grp, face_quad_grp): + nfaces = base_grp.mesh_el_group.nfaces + assert face_quad_grp.nelements == nfaces * base_grp.nelements + + from modepy import vandermonde, faces_for_shape + + basis = base_grp.basis_obj() + vdm_inv = np.linalg.inv(vandermonde(basis.functions, + base_grp.unit_nodes)) + faces = faces_for_shape(base_grp.shape) + # NOTE: Assumes same quadrature rule on each face + face_quadrature = face_quad_grp.quadrature_rule() + + surface_nodes = faces[0].map_to_volume(face_quadrature.nodes) + for fidx in range(1, nfaces): + surface_nodes = np.append( + surface_nodes, + faces[fidx].map_to_volume(face_quadrature.nodes), + axis=1 + ) + vdm_f = vandermonde(basis.functions, surface_nodes) @ vdm_inv + return actx.freeze(actx.from_numpy(vdm_f)) + + return get_surface_vand(base_element_group, face_quad_element_group) + + +def volume_and_surface_interpolation_matrix( + actx: ArrayContext, + base_element_group, vol_quad_element_group, face_quad_element_group): + @keyed_memoize_in( + actx, volume_and_surface_interpolation_matrix, + lambda base_grp, vol_quad_grp, face_quad_grp: ( + base_grp.discretization_key(), + vol_quad_grp.discretization_key(), + face_quad_grp.discretization_key())) + def get_vol_surf_interpolation_matrix(base_grp, vol_quad_grp, face_quad_grp): + vq_mat = actx.to_numpy( + volume_quadrature_interpolation_matrix( + actx, + base_element_group=base_grp, + vol_quad_element_group=vol_quad_grp)) + vf_mat = actx.to_numpy( + surface_quadrature_interpolation_matrix( + actx, + base_element_group=base_grp, + face_quad_element_group=face_quad_grp)) + return actx.freeze(actx.from_numpy(np.block([[vq_mat], [vf_mat]]))) + + return get_vol_surf_interpolation_matrix( + base_element_group, vol_quad_element_group, face_quad_element_group + ) + +# }}} + + +def volume_and_surface_quadrature_interpolation( + dcoll: DiscretizationCollection, + dd_quad: DOFDesc, + dd_face_quad: DOFDesc, + vec: ArrayOrContainerT) -> ArrayOrContainerT: + """todo. + """ + if not isinstance(vec, DOFArray): + return map_array_container( + partial(volume_and_surface_quadrature_interpolation, + dcoll, dd_quad, dd_face_quad), vec + ) + + actx = vec.array_context + discr = dcoll.discr_from_dd("vol") + quad_volm_discr = dcoll.discr_from_dd(dd_quad) + quad_face_discr = dcoll.discr_from_dd(dd_face_quad) + + return DOFArray( + actx, + data=tuple( + actx.einsum("ij,ej->ei", + volume_and_surface_interpolation_matrix( + actx, + base_element_group=bgrp, + vol_quad_element_group=qvgrp, + face_quad_element_group=qfgrp + ), + vec_i, + arg_names=("Vh_mat", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + + for bgrp, qvgrp, qfgrp, vec_i in zip( + discr.groups, + quad_volm_discr.groups, + quad_face_discr.groups, vec) + ) + ) diff --git a/grudge/op.py b/grudge/op.py index f5781f4be..d6dc959b2 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -95,7 +95,7 @@ ) from grudge.interpolation import interp -from grudge.projection import project +from grudge.projection import project, volume_quadrature_project from grudge.reductions import ( norm, @@ -1063,5 +1063,4 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: # }}} - # vim: foldmethod=marker diff --git a/grudge/projection.py b/grudge/projection.py index e21e02295..db7b10857 100644 --- a/grudge/projection.py +++ b/grudge/projection.py @@ -5,6 +5,7 @@ ----------- .. autofunction:: project +.. autofunction:: volume_quadrature_project """ from __future__ import annotations @@ -32,17 +33,25 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import numpy as np - +from arraycontext import map_array_container from arraycontext import ArrayOrContainer +from functools import partial + from grudge.discretization import DiscretizationCollection from grudge.dof_desc import ( as_dofdesc, VolumeDomainTag, BoundaryDomainTag, ConvertibleToDOFDesc) + +from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.dof_array import DOFArray + +from pytools import keyed_memoize_in from numbers import Number @@ -82,3 +91,59 @@ def project( return vec return dcoll.connection_from_dds(src_dofdesc, tgt_dofdesc)(vec) + +def volume_quadrature_project( + dcoll: DiscretizationCollection, dd_q, vec) -> ArrayOrContainerT: + """Projects a field on the quadrature discreization, described by *dd_q*, + into the polynomial space described by the volume discretization. + :arg dd_q: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. + """ + if not isinstance(vec, DOFArray): + return map_array_container( + partial(volume_quadrature_project, dcoll, dd_q), vec + ) + + from grudge.geometry import area_element + from grudge.interpolation import volume_quadrature_interpolation_matrix + from grudge.op import inverse_mass + + actx = vec.array_context + discr = dcoll.discr_from_dd("vol") + quad_discr = dcoll.discr_from_dd(dd_q) + jacobians = area_element( + actx, dcoll, dd=dd_q, + _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + + @keyed_memoize_in( + actx, volume_quadrature_project, + lambda base_grp, vol_quad_grp: (base_grp.discretization_key(), + vol_quad_grp.discretization_key())) + def get_mat(base_grp, vol_quad_grp): + vdm_q = actx.to_numpy( + volume_quadrature_interpolation_matrix( + actx, base_grp, vol_quad_grp + ) + ) + weights = np.diag(vol_quad_grp.quadrature_rule().weights) + return actx.freeze(actx.from_numpy(vdm_q.T @ weights)) + + return inverse_mass( + dcoll, + DOFArray( + actx, + data=tuple( + actx.einsum("ij,ej,ej->ei", + get_mat(bgrp, qgrp), + jac_i, + vec_i, + arg_names=("vqw_t", "jac", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + for bgrp, qgrp, vec_i, jac_i in zip( + discr.groups, quad_discr.groups, vec, jacobians) + ) + ) + )