diff --git a/irksome/base_time_stepper.py b/irksome/base_time_stepper.py index cf0410c4..792ad746 100644 --- a/irksome/base_time_stepper.py +++ b/irksome/base_time_stepper.py @@ -1,7 +1,7 @@ from abc import abstractmethod from firedrake import Function, NonlinearVariationalProblem, NonlinearVariationalSolver from firedrake.petsc import PETSc -from .tools import AI, get_stage_space, getNullspace, flatten_dats +from .tools import AI, get_stage_space, getNullspace, flatten_dats, get_stage_function class BaseTimeStepper: @@ -84,6 +84,7 @@ def __init__(self, F, t, dt, u0, num_stages, transpose_nullspace=None, near_nullspace=None, splitting=None, bc_type=None, butcher_tableau=None, bounds=None, + stage_functions=None, **kwargs): super().__init__(F, t, dt, u0, @@ -97,6 +98,11 @@ def __init__(self, F, t, dt, u0, num_stages, self.splitting = splitting self.bc_type = bc_type + if stage_functions is not None: + stage_functions = {w: get_stage_function(w, self.num_stages) + for w in stage_functions} + self.stage_functions = stage_functions + self.num_steps = 0 self.num_nonlinear_iterations = 0 self.num_linear_iterations = 0 diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index 4467be5d..79e8746f 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -13,7 +13,7 @@ from .base_time_stepper import StageCoupledTimeStepper -def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): +def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI, stage_functions=None): """Given a time-dependent variational form and a :class:`ButcherTableau`, produce UFL for the s-stage RK method. @@ -54,8 +54,13 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): if bc_type is None: bc_type = "DAE" + if stage_functions is None: + stage_functions = {} + stage_functions[u0] = stages + # preprocess time derivatives - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) + timedep_coeffs = tuple(stage_functions) + F = expand_time_derivatives(F, t=t, timedep_coeffs=timedep_coeffs) v, = F.arguments() V = v.function_space() assert V == u0.function_space() @@ -76,17 +81,18 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): # set up the pieces we need to work with to do our substitutions v_np = numpy.reshape(test, (num_stages, *u0.ufl_shape)) - w_np = numpy.reshape(stages, (num_stages, *u0.ufl_shape)) - A1w = A1 @ w_np - A2invw = A2inv @ w_np + w_np = {w: numpy.reshape(stage_functions[w], (num_stages, *w.ufl_shape)) + for w in stage_functions} + A1w = {w: A1 @ w_np[w] for w in w_np} + A2invw = {w: A2inv @ w_np[w] for w in w_np} - dtu = TimeDerivative(u0) repl = {} for i in range(num_stages): repl[i] = {t: t + c[i] * dt, - v: v_np[i], - u0: u0 + A1w[i] * dt, - dtu: A2invw[i]} + v: v_np[i]} + for w in w_np: + repl[i][w] = w + A1w[w][i] * dt + repl[i][TimeDerivative(w)] = A2invw[w][i] Fnew = sum(replace(F, repl[i]) for i in range(num_stages)) @@ -99,7 +105,7 @@ def bc2stagebc(bc, i): if isinstance(bc, EquationBCSplit): raise NotImplementedError("EquationBC not implemented for ODE formulation") gorig = as_ufl(bc._original_arg) - gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,)) + gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=timedep_coeffs) gcur = replace(gfoo, {t: t + c[i] * dt}) return BCStageData(bc, gcur, u0, stages, i) @@ -112,7 +118,7 @@ def bc2stagebc(bc, i): def bc2stagebc(bc, i): if isinstance(bc, EquationBCSplit): - F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=(u0,)) + F_bc_orig = expand_time_derivatives(bc.f, t=t, timedep_coeffs=timedep_coeffs) F_bc_new = replace(F_bc_orig, repl[i]) Vbigi = stage2spaces4bc(bc, V, Vbig, i) return EquationBC(F_bc_new == 0, stages, bc.sub_domain, V=Vbigi, @@ -179,11 +185,12 @@ class StageDerivativeTimeStepper(StageCoupledTimeStepper): associated with the Runge-Kutta method """ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, - solver_parameters=None, splitting=AI, + solver_parameters=None, splitting=AI, stage_functions=None, appctx=None, bc_type="DAE", **kwargs): self.num_fields = len(u0.function_space()) self.butcher_tableau = butcher_tableau + A1, A2 = splitting(butcher_tableau.A) try: self.updateb = vecconst(numpy.linalg.solve(A2.T, butcher_tableau.b)) @@ -195,7 +202,9 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=solver_parameters, appctx=appctx, splitting=splitting, bc_type=bc_type, - butcher_tableau=butcher_tableau, **kwargs) + butcher_tableau=butcher_tableau, + stage_functions=stage_functions, + **kwargs) def _update(self): """Assuming the algebraic problem for the RK stages has been @@ -216,7 +225,8 @@ def get_form_and_bcs(self, stages, tableau=None, F=None): tableau or self.butcher_tableau, self.t, self.dt, self.u0, stages, self.orig_bcs, self.bc_type, - self.splitting) + self.splitting, + self.stage_functions) class AdaptiveTimeStepper(StageDerivativeTimeStepper): diff --git a/irksome/stage_value.py b/irksome/stage_value.py index 675e9e06..278a1ff9 100644 --- a/irksome/stage_value.py +++ b/irksome/stage_value.py @@ -32,7 +32,7 @@ def to_value(u0, stages, vandermonde): return vandermonde[1:] @ u_np -def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermonde=None): +def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermonde=None, stage_functions=None): """Given a time-dependent variational form and a :class:`ButcherTableau`, produce UFL for the s-stage RK method. @@ -83,6 +83,10 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo V = v.function_space() assert V == u0.function_space() + if stage_functions is None: + stage_functions = {} + stage_functions[u0] = stages + c = vecconst(butch.c) bA1, bA2 = splitting(butch.A) try: @@ -99,7 +103,9 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo # set up the pieces we need to work with to do our substitutions v_np = numpy.reshape(test, (num_stages, *u0.ufl_shape)) - w_np = to_value(u0, stages, vandermonde) + w_np = {w: to_value(w, stage_functions[w], vandermonde) + for w in stage_functions} + A1Tv = A1.T @ v_np A2invTv = A2inv.T @ v_np @@ -107,7 +113,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo # assuming we have something of the form inner(Dt(g(u0)), v)*dx # For each stage i, this gets replaced with # inner((g(stages[i]) - g(u0))/dt, v)*dx - F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) + F = expand_time_derivatives(F, t=t, timedep_coeffs=tuple(stage_functions)) split_form = extract_terms(F) F_dtless = strip_dt_form(split_form.time) F_remainder = split_form.remainder @@ -116,16 +122,18 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo # Terms with time derivatives for i in range(num_stages): repl = {t: t + c[i] * dt, - v: A2invTv[i], - u0: w_np[i] - u0} + v: A2invTv[i]} + for w in w_np: + repl[w] = w_np[w][i] - w Fnew += replace(F_dtless, repl) # Handle the rest of the terms for i in range(num_stages): # replace the solution with stage values repl = {t: t + c[i] * dt, - v: A1Tv[i] * dt, - u0: w_np[i]} + v: A1Tv[i] * dt} + for w in w_np: + repl[w] = w_np[w][i] Fnew += replace(F_remainder, repl) if bcs is None: @@ -159,6 +167,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, splitting=AI, basis_type=None, appctx=None, bounds=None, use_collocation_update=False, + stage_functions=None, **kwargs): # we can only do DAE-type problems correctly if one assumes a stiffly-accurate method. @@ -169,6 +178,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, self.basis_type = basis_type degree = butcher_tableau.num_stages + num_stages = butcher_tableau.num_stages if basis_type is None or basis_type == 'Lagrange': vandermonde = None @@ -184,10 +194,11 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, vandermonde = vecconst(vandermonde) self.vandermonde = vandermonde - super().__init__(F, t, dt, u0, butcher_tableau.num_stages, bcs=bcs, + super().__init__(F, t, dt, u0, num_stages, bcs=bcs, solver_parameters=solver_parameters, appctx=appctx, - splitting=splitting, butcher_tableau=butcher_tableau, bounds=bounds, + splitting=splitting, butcher_tableau=butcher_tableau, + bounds=bounds, stage_functions=stage_functions, **kwargs) if use_collocation_update: @@ -266,4 +277,5 @@ def get_form_and_bcs(self, stages, tableau=None, F=None): self.t, self.dt, self.u0, stages, bcs=self.orig_bcs, splitting=self.splitting, + stage_functions=self.stage_functions, vandermonde=self.vandermonde) diff --git a/irksome/stepper.py b/irksome/stepper.py index 9d9703c9..d5ee9fc6 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -11,8 +11,8 @@ "appctx", "options_prefix", "pre_apply_bcs") valid_kwargs_per_stage_type = { - "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters"], - "value": ["stage_type", "basis_type", + "deriv": ["stage_type", "bc_type", "splitting", "adaptive_parameters", "stage_functions"], + "value": ["stage_type", "basis_type", "stage_functions", "update_solver_parameters", "splitting", "bounds", "use_collocation_update"], "dirk": ["stage_type", "bcs", "nullspace", "solver_parameters", "appctx"], "explicit": ["stage_type", "bcs", "solver_parameters", "appctx"], @@ -109,10 +109,13 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): if stage_type == "deriv": bc_type = kwargs.get("bc_type", "DAE") splitting = kwargs.get("splitting", AI) + stage_functions = kwargs.get("stage_functions") if adapt_params is None: return StageDerivativeTimeStepper( F, butcher_tableau, t, dt, u0, bcs, - bc_type=bc_type, splitting=splitting, **base_kwargs) + bc_type=bc_type, splitting=splitting, + stage_functions=stage_functions, + **base_kwargs) else: for param in adapt_params: assert param in valid_adapt_parameters @@ -138,11 +141,13 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): update_solver_parameters = kwargs.get("update_solver_parameters") bounds = kwargs.get("bounds") use_collocation_update = kwargs.get("use_collocation_update", False) + stage_functions = kwargs.get("stage_functions") return StageValueTimeStepper( F, butcher_tableau, t, dt, u0, bcs=bcs, splitting=splitting, basis_type=basis_type, update_solver_parameters=update_solver_parameters, bounds=bounds, use_collocation_update=use_collocation_update, + stage_functions=stage_functions, **base_kwargs) elif stage_type == "dirk": return DIRKTimeStepper( diff --git a/irksome/tools.py b/irksome/tools.py index 197ef204..e95295bc 100644 --- a/irksome/tools.py +++ b/irksome/tools.py @@ -24,6 +24,11 @@ def get_stage_space(V, num_stages): return reduce(mul, (V for _ in range(num_stages))) +def get_stage_function(w, num_stages): + Wbig = get_stage_space(w.function_space(), num_stages) + return Function(Wbig) + + def getNullspace(V, Vbig, num_stages, nullspace): """ Computes the nullspace for a multi-stage method. diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index 0e429b93..8c9b60ad 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -52,7 +52,8 @@ def heat(n, deg, time_stages, **kwargs): @pytest.mark.parametrize("kwargs", ({"stage_type": "deriv"}, - {"stage_type": "value"})) + {"stage_type": "value"}), + ids=("deriv", "value")) @pytest.mark.parametrize(('deg', 'convrate', 'time_stages'), [(1, 1.78, i) for i in (1, 2)] + [(2, 2.8, i) for i in (2, 3)])