Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion irksome/base_time_stepper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from firedrake import Function, NonlinearVariationalProblem, NonlinearVariationalSolver
from firedrake.petsc import PETSc
from .tools import AI, get_stage_space, getNullspace, flatten_dats
from .tools import AI, get_stage_space, getNullspace, flatten_dats, get_stage_function


class BaseTimeStepper:
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(self, F, t, dt, u0, num_stages,
transpose_nullspace=None, near_nullspace=None,
splitting=None, bc_type=None,
butcher_tableau=None, bounds=None,
stage_functions=None,
**kwargs):

super().__init__(F, t, dt, u0,
Expand All @@ -97,6 +98,11 @@ def __init__(self, F, t, dt, u0, num_stages,
self.splitting = splitting
self.bc_type = bc_type

if stage_functions is not None:
stage_functions = {w: get_stage_function(w, self.num_stages)
for w in stage_functions}
self.stage_functions = stage_functions

self.num_steps = 0
self.num_nonlinear_iterations = 0
self.num_linear_iterations = 0
Expand Down
38 changes: 24 additions & 14 deletions irksome/stage_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .base_time_stepper import StageCoupledTimeStepper


def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):
def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, stage_functions=None):
"""Given a time-dependent variational form and a
:class:`ButcherTableau`, produce UFL for the s-stage RK method.

Expand Down Expand Up @@ -54,8 +54,13 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):
if bc_type is None:
bc_type = "DAE"

if stage_functions is None:
stage_functions = {}
stage_functions[u0] = stages

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
timedep_coeffs = tuple(stage_functions)
F = expand_time_derivatives(F, t=t, timedep_coeffs=timedep_coeffs)
v, = F.arguments()
V = v.function_space()
assert V == u0.function_space()
Expand All @@ -76,17 +81,18 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):

# set up the pieces we need to work with to do our substitutions
v_np = numpy.reshape(test, (num_stages, *u0.ufl_shape))
w_np = numpy.reshape(stages, (num_stages, *u0.ufl_shape))
A1w = A1 @ w_np
A2invw = A2inv @ w_np
w_np = {w: numpy.reshape(stage_functions[w], (num_stages, *w.ufl_shape))
for w in stage_functions}
A1w = {w: A1 @ w_np[w] for w in w_np}
A2invw = {w: A2inv @ w_np[w] for w in w_np}

dtu = TimeDerivative(u0)
repl = {}
for i in range(num_stages):
repl[i] = {t: t + c[i] * dt,
v: v_np[i],
u0: u0 + A1w[i] * dt,
dtu: A2invw[i]}
v: v_np[i]}
for w in w_np:
repl[i][w] = w + A1w[w][i] * dt
repl[i][TimeDerivative(w)] = A2invw[w][i]

Fnew = sum(replace(F, repl[i]) for i in range(num_stages))

Expand All @@ -99,7 +105,7 @@ def bc2stagebc(bc, i):
if isinstance(bc, EquationBCSplit):
raise NotImplementedError("EquationBC not implemented for ODE formulation")
gorig = as_ufl(bc._original_arg)
gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,))
gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=timedep_coeffs)
gcur = replace(gfoo, {t: t + c[i] * dt})
return BCStageData(bc, gcur, u0, stages, i)

Expand All @@ -112,7 +118,7 @@ def bc2stagebc(bc, i):

def bc2stagebc(bc, i):
if isinstance(bc, EquationBCSplit):
F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u0,))
F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=timedep_coeffs)
F_bc_new = replace(F_bc_orig, repl[i])
Vbigi = stage2spaces4bc(bc, V, Vbig, i)
return EquationBC(F_bc_new == 0, stages, bc.sub_domain, V=Vbigi,
Expand Down Expand Up @@ -179,11 +185,12 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper):
associated with the Runge-Kutta method
"""
def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
solver_parameters=None, splitting=AI,
solver_parameters=None, splitting=AI, stage_functions=None,
appctx=None, bc_type="DAE", **kwargs):

self.num_fields = len(u0.function_space())
self.butcher_tableau = butcher_tableau

A1, A2 = splitting(butcher_tableau.A)
try:
self.updateb = vecconst(numpy.linalg.solve(A2.T, butcher_tableau.b))
Expand All @@ -195,7 +202,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
solver_parameters=solver_parameters,
appctx=appctx,
splitting=splitting, bc_type=bc_type,
butcher_tableau=butcher_tableau, **kwargs)
butcher_tableau=butcher_tableau,
stage_functions=stage_functions,
**kwargs)

def _update(self):
"""Assuming the algebraic problem for the RK stages has been
Expand All @@ -216,7 +225,8 @@ def get_form_and_bcs(self, stages, tableau=None, F=None):
tableau or self.butcher_tableau,
self.t, self.dt,
self.u0, stages, self.orig_bcs, self.bc_type,
self.splitting)
self.splitting,
self.stage_functions)


class AdaptiveTimeStepper(StageDerivativeTimeStepper):
Expand Down
30 changes: 21 additions & 9 deletions irksome/stage_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def to_value(u0, stages, vandermonde):
return vandermonde[1:] @ u_np


def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermonde=None):
def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermonde=None, stage_functions=None):
"""Given a time-dependent variational form and a
:class:`ButcherTableau`, produce UFL for the s-stage RK method.

Expand Down Expand Up @@ -83,6 +83,10 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo
V = v.function_space()
assert V == u0.function_space()

if stage_functions is None:
stage_functions = {}
stage_functions[u0] = stages

c = vecconst(butch.c)
bA1, bA2 = splitting(butch.A)
try:
Expand All @@ -99,15 +103,17 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo

# set up the pieces we need to work with to do our substitutions
v_np = numpy.reshape(test, (num_stages, *u0.ufl_shape))
w_np = to_value(u0, stages, vandermonde)
w_np = {w: to_value(w, stage_functions[w], vandermonde)
for w in stage_functions}

A1Tv = A1.T @ v_np
A2invTv = A2inv.T @ v_np

# first, process terms with a time derivative. I'm
# assuming we have something of the form inner(Dt(g(u0)), v)*dx
# For each stage i, this gets replaced with
# inner((g(stages[i]) - g(u0))/dt, v)*dx
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
F = expand_time_derivatives(F, t=t, timedep_coeffs=tuple(stage_functions))
split_form = extract_terms(F)
F_dtless = strip_dt_form(split_form.time)
F_remainder = split_form.remainder
Expand All @@ -116,16 +122,18 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo
# Terms with time derivatives
for i in range(num_stages):
repl = {t: t + c[i] * dt,
v: A2invTv[i],
u0: w_np[i] - u0}
v: A2invTv[i]}
for w in w_np:
repl[w] = w_np[w][i] - w
Fnew += replace(F_dtless, repl)

# Handle the rest of the terms
for i in range(num_stages):
# replace the solution with stage values
repl = {t: t + c[i] * dt,
v: A1Tv[i] * dt,
u0: w_np[i]}
v: A1Tv[i] * dt}
for w in w_np:
repl[w] = w_np[w][i]
Fnew += replace(F_remainder, repl)

if bcs is None:
Expand Down Expand Up @@ -159,6 +167,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
splitting=AI, basis_type=None,
appctx=None, bounds=None,
use_collocation_update=False,
stage_functions=None,
**kwargs):

# we can only do DAE-type problems correctly if one assumes a stiffly-accurate method.
Expand All @@ -169,6 +178,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
self.basis_type = basis_type

degree = butcher_tableau.num_stages
num_stages = butcher_tableau.num_stages

if basis_type is None or basis_type == 'Lagrange':
vandermonde = None
Expand All @@ -184,10 +194,11 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
vandermonde = vecconst(vandermonde)
self.vandermonde = vandermonde

super().__init__(F, t, dt, u0, butcher_tableau.num_stages, bcs=bcs,
super().__init__(F, t, dt, u0, num_stages, bcs=bcs,
solver_parameters=solver_parameters,
appctx=appctx,
splitting=splitting, butcher_tableau=butcher_tableau, bounds=bounds,
splitting=splitting, butcher_tableau=butcher_tableau,
bounds=bounds, stage_functions=stage_functions,
**kwargs)

if use_collocation_update:
Expand Down Expand Up @@ -266,4 +277,5 @@ def get_form_and_bcs(self, stages, tableau=None, F=None):
self.t, self.dt, self.u0,
stages, bcs=self.orig_bcs,
splitting=self.splitting,
stage_functions=self.stage_functions,
vandermonde=self.vandermonde)
11 changes: 8 additions & 3 deletions irksome/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
"appctx", "options_prefix", "pre_apply_bcs")

valid_kwargs_per_stage_type = {
"deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters"],
"value": ["stage_type", "basis_type",
"deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters", "stage_functions"],
"value": ["stage_type", "basis_type", "stage_functions",
"update_solver_parameters", "splitting", "bounds", "use_collocation_update"],
"dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"],
"explicit": ["stage_type", "bcs", "solver_parameters", "appctx"],
Expand Down Expand Up @@ -109,10 +109,13 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs):
if stage_type == "deriv":
bc_type = kwargs.get("bc_type", "DAE")
splitting = kwargs.get("splitting", AI)
stage_functions = kwargs.get("stage_functions")
if adapt_params is None:
return StageDerivativeTimeStepper(
F, butcher_tableau, t, dt, u0, bcs,
bc_type=bc_type, splitting=splitting, **base_kwargs)
bc_type=bc_type, splitting=splitting,
stage_functions=stage_functions,
**base_kwargs)
else:
for param in adapt_params:
assert param in valid_adapt_parameters
Expand All @@ -138,11 +141,13 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs):
update_solver_parameters = kwargs.get("update_solver_parameters")
bounds = kwargs.get("bounds")
use_collocation_update = kwargs.get("use_collocation_update", False)
stage_functions = kwargs.get("stage_functions")
return StageValueTimeStepper(
F, butcher_tableau, t, dt, u0, bcs=bcs,
splitting=splitting, basis_type=basis_type,
update_solver_parameters=update_solver_parameters,
bounds=bounds, use_collocation_update=use_collocation_update,
stage_functions=stage_functions,
**base_kwargs)
elif stage_type == "dirk":
return DIRKTimeStepper(
Expand Down
5 changes: 5 additions & 0 deletions irksome/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def get_stage_space(V, num_stages):
return reduce(mul, (V for _ in range(num_stages)))


def get_stage_function(w, num_stages):
Wbig = get_stage_space(w.function_space(), num_stages)
return Function(Wbig)


def getNullspace(V, Vbig, num_stages, nullspace):
"""
Computes the nullspace for a multi-stage method.
Expand Down
3 changes: 2 additions & 1 deletion tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def heat(n, deg, time_stages, **kwargs):


@pytest.mark.parametrize("kwargs", ({"stage_type": "deriv"},
{"stage_type": "value"}))
{"stage_type": "value"}),
ids=("deriv", "value"))
@pytest.mark.parametrize(('deg', 'convrate', 'time_stages'),
[(1, 1.78, i) for i in (1, 2)]
+ [(2, 2.8, i) for i in (2, 3)])
Expand Down