Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torax/_src/config/build_runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __call__(
numerics=self.numerics.build_runtime_params(t),
neoclassical=self.neoclassical.build_runtime_params(),
pedestal=self.pedestal.build_runtime_params(t),
pedestal_policy=self.pedestal.build_pedestal_policy_runtime_params(),
mhd=self.mhd.build_runtime_params(t),
time_step_calculator=self.time_step_calculator.build_runtime_params(),
edge=None if self.edge is None else self.edge.build_runtime_params(t),
Expand Down
2 changes: 2 additions & 0 deletions torax/_src/config/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from torax._src.mhd import runtime_params as mhd_runtime_params
from torax._src.neoclassical import runtime_params as neoclassical_params
from torax._src.pedestal_model import runtime_params as pedestal_model_params
from torax._src.pedestal_policy import runtime_params as pedestal_policy_runtime_params
from torax._src.solver import runtime_params as solver_params
from torax._src.sources import runtime_params as sources_params
from torax._src.time_step_calculator import runtime_params as time_step_calculator_runtime_params
Expand Down Expand Up @@ -80,6 +81,7 @@ class RuntimeParams:
neoclassical: neoclassical_params.RuntimeParams
numerics: numerics.RuntimeParams
pedestal: pedestal_model_params.RuntimeParams
pedestal_policy: pedestal_policy_runtime_params.PedestalPolicyRuntimeParams
plasma_composition: plasma_composition.RuntimeParams
profile_conditions: profile_conditions.RuntimeParams
solver: solver_params.RuntimeParams
Expand Down
17 changes: 15 additions & 2 deletions torax/_src/config/tests/build_runtime_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,32 @@ def test_pedestal_is_time_dependent(self):
set_pedestal={0.0: True, 1.0: False},
)
)
pedestal_policy = pedestal.build_pedestal_model().pedestal_policy
# Check at time 0.

pedestal_params = pedestal.build_runtime_params(t=0.0)
pedestal_policy_rp = pedestal.build_pedestal_policy_runtime_params()
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
np.testing.assert_allclose(pedestal_params.set_pedestal, True)
np.testing.assert_allclose(
pedestal_policy.initial_state(
t=0.0, runtime_params=pedestal_policy_rp
).use_pedestal,
True,
)
np.testing.assert_allclose(pedestal_params.T_i_ped, 0.0)
np.testing.assert_allclose(pedestal_params.T_e_ped, 1.0)
np.testing.assert_allclose(pedestal_params.n_e_ped, 2.0e20)
np.testing.assert_allclose(pedestal_params.rho_norm_ped_top, 3.0)
# And check after the time limit.
pedestal_params = pedestal.build_runtime_params(t=1.0)
# Note: pedestal_policy_rp does not depend on time for its structure
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
np.testing.assert_allclose(pedestal_params.set_pedestal, False)
np.testing.assert_allclose(
pedestal_policy.initial_state(
t=1.0, runtime_params=pedestal_policy_rp
).use_pedestal,
False,
)
np.testing.assert_allclose(pedestal_params.T_i_ped, 1.0)
np.testing.assert_allclose(pedestal_params.T_e_ped, 2.0)
np.testing.assert_allclose(pedestal_params.n_e_ped, 3.0e20)
Expand Down
45 changes: 43 additions & 2 deletions torax/_src/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torax._src.fvm import cell_variable
from torax._src.geometry import geometry
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
from torax._src.pedestal_policy import pedestal_policy
from torax._src.sources import source_profile_builders
from torax._src.sources import source_profiles as source_profiles_lib
import typing_extensions
Expand Down Expand Up @@ -63,6 +64,7 @@ def __call__(
core_profiles: state.CoreProfiles,
x: tuple[cell_variable.CellVariable, ...],
explicit_source_profiles: source_profiles_lib.SourceProfiles,
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
allow_pereverzev: bool = False,
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
# should be called
Expand All @@ -86,6 +88,7 @@ def __call__(
not recalculated at time t+plus_dt with updated state during the solver
iterations. For sources that are implicit, their explicit profiles are
set to all zeros.
pedestal_policy_state: State held by the pedestal policy.
allow_pereverzev: If True, then the coeffs are being called within a
linear solver. Thus could be either the use_predictor_corrector solver
or as part of calculating the initial guess for the nonlinear solver. In
Expand All @@ -101,6 +104,16 @@ def __call__(
coeffs: The diffusion, convection, etc. coefficients for this state.
"""

# There are cases where pytype fails to enforce this
if not isinstance(
pedestal_policy_state, pedestal_policy.PedestalPolicyState
):
raise TypeError(
'Expected `pedestal_policy_state` to be of type '
'`pedestal_policy.PedestalPolicyState`',
f'got `{type(pedestal_policy_state)}`.',
)

# Update core_profiles with the subset of new values of evolving variables
core_profiles = updaters.update_core_profiles_during_step(
x,
Expand All @@ -121,6 +134,7 @@ def __call__(
explicit_source_profiles=explicit_source_profiles,
physics_models=self.physics_models,
evolving_names=self.evolving_names,
pedestal_policy_state=pedestal_policy_state,
use_pereverzev=use_pereverzev,
explicit_call=explicit_call,
)
Expand Down Expand Up @@ -219,6 +233,7 @@ def calc_coeffs(
explicit_source_profiles: source_profiles_lib.SourceProfiles,
physics_models: physics_models_lib.PhysicsModels,
evolving_names: tuple[str, ...],
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
use_pereverzev: bool = False,
explicit_call: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
Expand All @@ -241,6 +256,7 @@ def calc_coeffs(
physics_models: The physics models to use for the simulation.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
pedestal_policy_state: State held by the pedestal policy.
use_pereverzev: Toggle whether to calculate Pereverzev terms
explicit_call: If True, indicates that calc_coeffs is being called for the
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
Expand All @@ -251,6 +267,14 @@ def calc_coeffs(
coeffs: Block1DCoeffs containing the coefficients at this time step.
"""

# There are cases where pytype fails to enforce this
if not isinstance(pedestal_policy_state, pedestal_policy.PedestalPolicyState):
raise TypeError(
'Expected `pedestal_policy_state` to be of type '
'`pedestal_policy.PedestalPolicyState`',
f'got `{type(pedestal_policy_state)}`.',
)

# If we are fully implicit and we are making a call for calc_coeffs for the
# explicit components of the PDE, only return a cheaper reduced Block1DCoeffs
if explicit_call and runtime_params.solver.theta_implicit == 1.0:
Expand All @@ -267,6 +291,7 @@ def calc_coeffs(
explicit_source_profiles=explicit_source_profiles,
physics_models=physics_models,
evolving_names=evolving_names,
pedestal_policy_state=pedestal_policy_state,
use_pereverzev=use_pereverzev,
)

Expand All @@ -285,14 +310,26 @@ def _calc_coeffs_full(
explicit_source_profiles: source_profiles_lib.SourceProfiles,
physics_models: physics_models_lib.PhysicsModels,
evolving_names: tuple[str, ...],
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
use_pereverzev: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
"""See `calc_coeffs` for details."""

consts = constants.CONSTANTS

# There are cases where pytype fails to enforce this
if not isinstance(pedestal_policy_state, pedestal_policy.PedestalPolicyState):
raise TypeError(
'Expected `pedestal_policy_state` to be of type '
'`pedestal_policy.PedestalPolicyState`',
f'got `{type(pedestal_policy_state)}`.',
)

pedestal_model_output = physics_models.pedestal_model(
runtime_params, geo, core_profiles
runtime_params,
geo,
core_profiles,
pedestal_policy_state=pedestal_policy_state,
)

# Boolean mask for enforcing internal temperature boundary conditions to
Expand Down Expand Up @@ -352,7 +389,11 @@ def _calc_coeffs_full(

# Diffusion term coefficients
turbulent_transport = physics_models.transport_model(
runtime_params, geo, core_profiles, pedestal_model_output
runtime_params,
geo,
core_profiles,
pedestal_policy_state,
pedestal_model_output,
)
neoclassical_transport = physics_models.neoclassical_models.transport(
runtime_params, geo, core_profiles
Expand Down
9 changes: 9 additions & 0 deletions torax/_src/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torax._src.fvm import fvm_conversions
from torax._src.fvm import residual_and_loss
from torax._src.geometry import geometry
from torax._src.pedestal_policy import pedestal_policy
from torax._src.solver import jax_root_finding
from torax._src.solver import predictor_corrector_method
from torax._src.sources import source_profiles
Expand Down Expand Up @@ -68,6 +69,8 @@ def newton_raphson_solve_block(
physics_models: physics_models_lib.PhysicsModels,
coeffs_callback: calc_coeffs.CoeffsCallback,
evolving_names: tuple[str, ...],
pedestal_policy_state_t: pedestal_policy.PedestalPolicyState,
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
initial_guess_mode: enums.InitialGuessMode,
maxiter: int,
tol: float,
Expand Down Expand Up @@ -129,6 +132,8 @@ def newton_raphson_solve_block(
core_profiles. Repeatedly called by the iterative optimizer.
evolving_names: The names of variables within the core profiles that should
evolve.
pedestal_policy_state_t: Pedestal policy state at time t
pedestal_policy_state_t_plus_dt: Pedestal policy state at time t + dt
initial_guess_mode: chooses the initial_guess for the iterative method,
either x_old or linear step. When taking the linear step, it is also
recommended to use Pereverzev-Corrigan terms if the transport coefficients
Expand Down Expand Up @@ -160,6 +165,7 @@ def newton_raphson_solve_block(
core_profiles_t,
x_old,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state=pedestal_policy_state_t,
explicit_call=True,
)

Expand All @@ -176,6 +182,7 @@ def newton_raphson_solve_block(
core_profiles_t,
x_old,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state=pedestal_policy_state_t,
allow_pereverzev=True,
explicit_call=True,
)
Expand All @@ -194,6 +201,7 @@ def newton_raphson_solve_block(
coeffs_exp=coeffs_exp_linear,
coeffs_callback=coeffs_callback,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
)
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
case enums.InitialGuessMode.X_OLD:
Expand All @@ -215,6 +223,7 @@ def newton_raphson_solve_block(
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
physics_models=physics_models,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
coeffs_old=coeffs_old,
evolving_names=evolving_names,
)
Expand Down
10 changes: 10 additions & 0 deletions torax/_src/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax._src.fvm import fvm_conversions
from torax._src.fvm import residual_and_loss
from torax._src.geometry import geometry
from torax._src.pedestal_policy import pedestal_policy
from torax._src.solver import predictor_corrector_method
from torax._src.sources import source_profiles

Expand All @@ -57,6 +58,8 @@ def optimizer_solve_block(
core_profiles_t: state.CoreProfiles,
core_profiles_t_plus_dt: state.CoreProfiles,
explicit_source_profiles: source_profiles.SourceProfiles,
pedestal_policy_state_t: pedestal_policy.PedestalPolicyState,
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
physics_models: physics_models_lib.PhysicsModels,
coeffs_callback: calc_coeffs.CoeffsCallback,
evolving_names: tuple[str, ...],
Expand Down Expand Up @@ -98,6 +101,9 @@ def optimizer_solve_block(
being evolved by the PDE system.
explicit_source_profiles: Pre-calculated sources implemented as explicit
sources in the PDE.
pedestal_policy_state_t: State variables held by the pedestal policy.
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
policy.
physics_models: Physics models used for the calculations.
coeffs_callback: Calculates diffusion, convection etc. coefficients given a
core_profiles. Repeatedly called by the iterative optimizer.
Expand All @@ -124,6 +130,7 @@ def optimizer_solve_block(
core_profiles_t,
x_old,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state=pedestal_policy_state_t,
explicit_call=True,
)

Expand All @@ -141,6 +148,7 @@ def optimizer_solve_block(
core_profiles_t,
x_old,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state=pedestal_policy_state_t,
allow_pereverzev=True,
explicit_call=True,
)
Expand All @@ -158,6 +166,7 @@ def optimizer_solve_block(
coeffs_exp=coeffs_exp_linear,
coeffs_callback=coeffs_callback,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
)
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
case enums.InitialGuessMode.X_OLD:
Expand All @@ -180,6 +189,7 @@ def optimizer_solve_block(
init_x_new_vec=init_x_new_vec,
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
physics_models=physics_models,
coeffs_old=coeffs_old,
evolving_names=evolving_names,
Expand Down
13 changes: 13 additions & 0 deletions torax/_src/fvm/residual_and_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torax._src.fvm import discrete_system
from torax._src.fvm import fvm_conversions
from torax._src.geometry import geometry
from torax._src.pedestal_policy import pedestal_policy
from torax._src.sources import source_profiles

Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
Expand Down Expand Up @@ -201,6 +202,7 @@ def theta_method_block_residual(
x_old: tuple[cell_variable.CellVariable, ...],
core_profiles_t_plus_dt: state.CoreProfiles,
explicit_source_profiles: source_profiles.SourceProfiles,
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
physics_models: physics_models_lib.PhysicsModels,
coeffs_old: Block1DCoeffs,
evolving_names: tuple[str, ...],
Expand All @@ -220,6 +222,8 @@ def theta_method_block_residual(
being evolved by the PDE system.
explicit_source_profiles: Pre-calculated sources implemented as explicit
sources in the PDE.
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
policy.
physics_models: Physics models used for the calculations.
coeffs_old: The coefficients calculated at x_old.
evolving_names: The names of variables within the core profiles that should
Expand Down Expand Up @@ -252,6 +256,7 @@ def theta_method_block_residual(
core_profiles=core_profiles_t_plus_dt,
explicit_source_profiles=explicit_source_profiles,
physics_models=physics_models,
pedestal_policy_state=pedestal_policy_state_t_plus_dt,
evolving_names=evolving_names,
use_pereverzev=False,
)
Expand Down Expand Up @@ -290,6 +295,7 @@ def theta_method_block_loss(
x_old: tuple[cell_variable.CellVariable, ...],
core_profiles_t_plus_dt: state.CoreProfiles,
explicit_source_profiles: source_profiles.SourceProfiles,
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
physics_models: physics_models_lib.PhysicsModels,
coeffs_old: Block1DCoeffs,
evolving_names: tuple[str, ...],
Expand All @@ -309,6 +315,8 @@ def theta_method_block_loss(
being evolved by the PDE system.
explicit_source_profiles: pre-calculated sources implemented as explicit
sources in the PDE
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
policy.
physics_models: Physics models used for the calculations.
coeffs_old: The coefficients calculated at x_old.
evolving_names: The names of variables within the core profiles that should
Expand All @@ -326,6 +334,7 @@ def theta_method_block_loss(
x_new_guess_vec=x_new_guess_vec,
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
explicit_source_profiles=explicit_source_profiles,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
physics_models=physics_models,
coeffs_old=coeffs_old,
evolving_names=evolving_names,
Expand All @@ -349,6 +358,7 @@ def jaxopt_solver(
init_x_new_vec: jax.Array,
core_profiles_t_plus_dt: state.CoreProfiles,
explicit_source_profiles: source_profiles.SourceProfiles,
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicy,
physics_models: physics_models_lib.PhysicsModels,
coeffs_old: Block1DCoeffs,
evolving_names: tuple[str, ...],
Expand All @@ -370,6 +380,8 @@ def jaxopt_solver(
being evolved by the PDE system.
explicit_source_profiles: pre-calculated sources implemented as explicit
sources in the PDE.
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
policy.
physics_models: Physics models used for the calculations.
coeffs_old: The coefficients calculated at x_old.
evolving_names: The names of variables within the core profiles that should
Expand All @@ -394,6 +406,7 @@ def jaxopt_solver(
physics_models=physics_models,
coeffs_old=coeffs_old,
evolving_names=evolving_names,
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
)
solver = jaxopt.LBFGS(fun=loss, maxiter=maxiter, tol=tol, implicit_diff=True)
solver_output = solver.run(init_x_new_vec)
Expand Down
Loading