Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,10 @@
from .scheme import ContinuousPetrovGalerkinScheme, DiscontinuousGalerkinScheme # noqa: F401
from .galerkin_stepper import ContinuousPetrovGalerkinTimeStepper # noqa: F401
from .discontinuous_galerkin_stepper import DiscontinuousGalerkinTimeStepper # noqa: F401
from .labeling import TimeQuadratureLabel # noqa: F401
from .labeling import (
TimeQuadratureLabel, MeasureOverride, # noqa: F401
dx_override, ds_override, dS_override, dr_override, dP_override, # noqa: F401
dc_override, dC_override, dI_override, dO_override, # noqa: F401
ds_b_override, ds_t_override, ds_v_override, # noqa: F401
dS_h_override, dS_v_override, # noqa: F401
)
186 changes: 167 additions & 19 deletions irksome/labeling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from firedrake.fml import Label, keep, drop, LabelledForm
from .scheme import create_time_quadrature
from ufl.form import Form
from ufl.measure import Measure
import numpy as np

explicit = Label("explicit")
Expand Down Expand Up @@ -33,27 +35,81 @@ def get_weights(self):


def split_quadrature(F, Qdefault=None):
if not isinstance(F, LabelledForm):
"""Split a form into subforms grouped by time quadrature rule.

Supports two mechanisms:
1) firedrake.fml labels using TimeQuadratureLabel/TimeQuadratureRule
2) UFL integral metadata containing Irksome keys
("quadrature_degree_time" and optionally "quadrature_scheme_time").

If neither labelling nor metadata overrides are present, returns
a single entry mapping Qdefault -> F.
"""
# Case 1: LabelledForm path (existing behaviour)
if isinstance(F, LabelledForm):
quad_labels = set()
for term in F.terms:
cur_labels = [label for label in term.labels if isinstance(label, TimeQuadratureRule)]
if len(cur_labels) == 1:
quad_labels.update(cur_labels)
elif len(cur_labels) > 1:
raise ValueError("Multiple quadrature labels on one term.")

splitting = {Q: F.label_map(lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop)
for Q in quad_labels}
splitting[Qdefault] = F.label_map(lambda t: len(quad_labels.intersection(t.labels)) > 0,
map_if_true=drop, map_if_false=keep)
for Q in list(splitting):
try:
splitting[Q] = splitting[Q].form
except TypeError:
splitting.pop(Q)
return splitting

# Case 2: Plain UFL form with per-integral metadata overrides
# See if I can recover integral; it not, return default
try:
integrals = F.integrals()
except Exception:
return {Qdefault: F}

# Scan for Irksome metadata; if none present, return default
IRK_DEG = "quadrature_degree_override"
IRK_SCH = "quadrature_scheme_override"
has_override = any(
(IRK_DEG in (I.metadata() or {}) or IRK_SCH in (I.metadata() or {}))
for I in integrals
)
if not has_override:
return {Qdefault: F}

quad_labels = set()
for term in F.terms:
cur_labels = [label for label in term.labels if isinstance(label, TimeQuadratureRule)]
if len(cur_labels) == 1:
quad_labels.update(cur_labels)
elif len(cur_labels) > 1:
raise ValueError("Multiple quadrature labels on one term.")

splitting = {Q: F.label_map(lambda t: Q in t.labels, map_if_true=keep, map_if_false=drop)
for Q in quad_labels}
splitting[Qdefault] = F.label_map(lambda t: len(quad_labels.intersection(t.labels)) > 0,
map_if_true=drop, map_if_false=keep)
for Q in list(splitting):
try:
splitting[Q] = splitting[Q].form
except TypeError:
splitting.pop(Q)
return splitting
# Since we got here, build groups keyed by (degree, scheme) tuples
groups = {}
default_ints = []
# For each integral...
for I in integrals:
# ...get the metadata...
md = I.metadata() or {}
deg = md.get(IRK_DEG, None)
sch = md.get(IRK_SCH, None)
if deg is None:
# ...if no quadrature override is specified, add to default...
default_ints.append(I)
else:
# ...and otherwise, record in groups
sch = sch if sch is not None else "default"
key = (int(deg), str(sch))
groups.setdefault(key, []).append(I)

# Now, assemble into a dictionary as required using create_time_quadrature
result = {}
for (deg, sch), ints in groups.items():
Q = create_time_quadrature(deg, scheme=sch)
result[Q] = Form(ints)
if default_ints:
result[Qdefault] = Form(default_ints)

return result


def split_explicit(F):
Expand All @@ -67,3 +123,95 @@ def split_explicit(F):
map_if_true=keep, map_if_false=drop)

return imp_part.form, exp_part.form


class MeasureOverride(Measure):
"""Thin wrappers around UFL Measures that allow users to tag
individual integrals with Irksome-specific overrides for
time quadrature used by Galerkin-in-time discretisations.

Usage example:
F = inner(Dt(u), v) * dx_override(time_degree_override=5) + inner(u, v) * dx

Here, only the first term will be integrated in time with a rule of
degree 5; the other terms will use the scheme defaults.
"""
def __call__(
self,
subdomain_id=None,
metadata=None,
domain=None,
subdomain_data=None,
degree=None,
scheme=None,
*,
time_degree_override=None,
time_scheme_override=None,
):
"""Reconfigure measure with (optional) time quadrature overrides.

The optional keyword-only arguments time_degree_override and time_scheme_override
are stored in metadata keys understood by Irksome's Galerkin-in-time
machinery in split_quadrature().
"""
# Inject time overrides into metadata
if time_degree_override is None and time_scheme_override is not None:
raise ValueError(
"Time quadrature override requires specification of time_degree_override."
)
if time_degree_override is not None or time_scheme_override is not None:
metadata = {} if metadata is None else metadata.copy()
if time_degree_override is not None:
metadata["quadrature_degree_override"] = time_degree_override
if time_scheme_override is not None:
metadata["quadrature_scheme_override"] = time_scheme_override

# Inject spatial (degree, scheme) into metadata if provided, mirroring
# ufl.measure.Measure.__call__ semantics.
if (degree, scheme) != (None, None):
metadata = {} if metadata is None else metadata.copy()
if degree is not None:
metadata["quadrature_degree"] = degree
if scheme is not None:
metadata["quadrature_rule"] = scheme

# Support dx(domain) style: if first positional looks like a domain, treat accordingly
if subdomain_id is not None and hasattr(subdomain_id, "ufl_domain"):
if domain is not None:
raise ValueError(
"Ambiguous: setting domain both as keyword argument and first argument."
)
subdomain_id, domain = "everywhere", subdomain_id

# Without args, return everywhere
if all(x is None for x in (subdomain_id, metadata, domain, subdomain_data, degree, scheme)) and (
time_degree_override is None and time_scheme_override is None
):
return self.reconstruct(subdomain_id="everywhere")

# Construct new Measure
return Measure(
self.integral_type(),
domain=domain or self.ufl_domain(),
subdomain_id=subdomain_id if subdomain_id is not None else self.subdomain_id(),
metadata=metadata if metadata is not None else self.metadata(),
subdomain_data=subdomain_data if subdomain_data is not None else self.subdomain_data(),
)


# Convenience instances mirroring Firedrake/UFL defaults
dx_override = MeasureOverride("cell")
ds_override = MeasureOverride("exterior_facet")
dS_override = MeasureOverride("interior_facet")
dr_override = MeasureOverride("ridge")
dP_override = MeasureOverride("vertex")
dc_override = MeasureOverride("custom")
dC_override = MeasureOverride("cutcell")
dI_override = MeasureOverride("interface")
dO_override = MeasureOverride("overlap")
ds_b_override = MeasureOverride("exterior_facet_bottom")
ds_t_override = MeasureOverride("exterior_facet_top")
ds_v_override = MeasureOverride("exterior_facet_vert")
dS_h_override = MeasureOverride("interior_facet_horiz")
dS_v_override = MeasureOverride("interior_facet_vert")

56 changes: 56 additions & 0 deletions tests/test_measureoverride.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from firedrake import *
from irksome import Dt, TimeStepper, ContinuousPetrovGalerkinScheme, dx_override

@pytest.mark.parametrize("order", [1, 2, 3])
@pytest.mark.parametrize("scheme", ["gauss", "cpg"])
def test_nls(order, scheme):
# Domain and space
mesh = PeriodicUnitIntervalMesh(10)
x, = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "CG", 1)
Z = V * V

# State and test functions
psi = Function(Z)
a, b = split(psi)
c, d = TestFunctions(Z)

# Initial condition: cosine
psi.project(as_vector([cos(x), 0]))

# Time parameters
t = Constant(0.0)
dt = Constant(0.1)

# Residual
dx_highorder = dx if scheme == "gauss" else dx_override(time_degree_override=4*order-1)
amp_sq = a**2 + b**2
F = (
inner(Dt(b), c) * dx
+ 0.5 * inner(grad(a), grad(c)) * dx
- inner(amp_sq * a, c) * dx_highorder
- inner(Dt(a), d) * dx
+ 0.5 * inner(grad(b), grad(d)) * dx
- inner(amp_sq * b, d) * dx_highorder
)

# Energy
E = 0.5 * (inner(grad(a), grad(a)) + inner(grad(b), grad(b)) - amp_sq**2) * dx

# Time stepper with cPG(k); default time quadrature is 2k-1
scheme_ = ContinuousPetrovGalerkinScheme(order=order, quadrature_degree=2*order-1)
stepper = TimeStepper(F, scheme_, t, dt, psi)

# Record initial energy
E0 = float(assemble(E))

# Advance once
stepper.advance()

# Final energy and drift
E1 = float(assemble(E))
drift = abs(E1 - E0)
if scheme == "gauss": assert drift > 1e-10
else: assert drift < 1e-10