From 1972bd531a444dca8ae3dcd1559a1a072dd6fd7b Mon Sep 17 00:00:00 2001 From: akshat3410 Date: Sat, 3 Jan 2026 14:42:40 +0530 Subject: [PATCH] feat(numerics): implement minimum allowed temperature (#1607) Add min_temperature config option to stop simulation when temperatures fall below a user-defined threshold. This is useful for avoiding numerical issues during radiative collapse scenarios. Changes: - Add min_temperature field to Numerics config (default: 0.0 = disabled) - Add temperature_below_minimum() method to CoreProfiles - Add BELOW_MIN_TEMPERATURE error type to SimError enum - Integrate check into SimState.check_for_errors() - Pass min_temperature from step_function to error checker Fixes #1607 --- torax/_src/config/numerics.py | 7 ++++ torax/_src/orchestration/sim_state.py | 21 +++++++++-- torax/_src/orchestration/step_function.py | 4 ++- torax/_src/state.py | 26 ++++++++++++++ torax/_src/tests/state_test.py | 44 +++++++++++++++++++++++ 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/torax/_src/config/numerics.py b/torax/_src/config/numerics.py index 8bb063c1a..7d7e96cfd 100644 --- a/torax/_src/config/numerics.py +++ b/torax/_src/config/numerics.py @@ -45,6 +45,7 @@ class RuntimeParams: resistivity_multiplier: array_typing.FloatScalar adaptive_T_source_prefactor: float adaptive_n_source_prefactor: float + min_temperature: float evolve_ion_heat: bool = dataclasses.field(metadata={'static': True}) evolve_electron_heat: bool = dataclasses.field(metadata={'static': True}) evolve_current: bool = dataclasses.field(metadata={'static': True}) @@ -102,6 +103,10 @@ class Numerics(torax_pydantic.BaseModelFrozen): temperature internal boundary conditions. adaptive_n_source_prefactor: Prefactor for adaptive source term for setting density internal boundary conditions. + min_temperature: Minimum allowed temperature in keV. If any temperature + (T_e or T_i) falls below this threshold, the simulation will exit with + an error. This is useful for avoiding numerical issues during radiative + collapse scenarios. Default is 0.0 (only negative temperatures trigger). """ t_initial: torax_pydantic.Second = 0.0 @@ -125,6 +130,7 @@ class Numerics(torax_pydantic.BaseModelFrozen): ) adaptive_T_source_prefactor: pydantic.PositiveFloat = 2.0e10 adaptive_n_source_prefactor: pydantic.PositiveFloat = 2.0e8 + min_temperature: pydantic.NonNegativeFloat = 0.0 @pydantic.model_validator(mode='after') def model_validation(self) -> Self: @@ -168,6 +174,7 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams: resistivity_multiplier=self.resistivity_multiplier.get_value(t), adaptive_T_source_prefactor=self.adaptive_T_source_prefactor, adaptive_n_source_prefactor=self.adaptive_n_source_prefactor, + min_temperature=self.min_temperature, evolve_ion_heat=self.evolve_ion_heat, evolve_electron_heat=self.evolve_electron_heat, evolve_current=self.evolve_current, diff --git a/torax/_src/orchestration/sim_state.py b/torax/_src/orchestration/sim_state.py index d9a20a230..0e16ffde2 100644 --- a/torax/_src/orchestration/sim_state.py +++ b/torax/_src/orchestration/sim_state.py @@ -55,8 +55,25 @@ class SimState: geometry: geometry.Geometry solver_numeric_outputs: state.SolverNumericOutputs - def check_for_errors(self) -> state.SimError: - """Checks for errors in the simulation state.""" + def check_for_errors( + self, + min_temperature: float = 0.0, + ) -> state.SimError: + """Checks for errors in the simulation state. + + Args: + min_temperature: Minimum allowed temperature in keV. If any temperature + falls below this threshold, returns BELOW_MIN_TEMPERATURE error. + + Returns: + SimError indicating the type of error, or NO_ERROR if none. + """ + if self.core_profiles.temperature_below_minimum(min_temperature): + logging.info( + "Temperature below minimum threshold (%s keV) detected.", + min_temperature, + ) + return state.SimError.BELOW_MIN_TEMPERATURE if self.core_profiles.negative_temperature_or_density(): logging.info("Unphysical negative values detected in core profiles:\n") _log_negative_profile_names(self.core_profiles) diff --git a/torax/_src/orchestration/step_function.py b/torax/_src/orchestration/step_function.py index 89dfc5c5e..7dc687b17 100644 --- a/torax/_src/orchestration/step_function.py +++ b/torax/_src/orchestration/step_function.py @@ -151,7 +151,9 @@ def check_for_errors( < self._runtime_params_provider.numerics.min_dt ): return state.SimError.REACHED_MIN_DT - state_error = output_state.check_for_errors() + state_error = output_state.check_for_errors( + min_temperature=self._runtime_params_provider.numerics.min_temperature, + ) if state_error != state.SimError.NO_ERROR: return state_error else: diff --git a/torax/_src/state.py b/torax/_src/state.py index bfe36d021..bcd0934b6 100644 --- a/torax/_src/state.py +++ b/torax/_src/state.py @@ -167,6 +167,24 @@ def negative_temperature_or_density(self) -> jax.Array: ]) ) + def temperature_below_minimum(self, min_temperature: float) -> jax.Array: + """Checks if any temperature is below the minimum threshold. + + Args: + min_temperature: Minimum allowed temperature in keV. + + Returns: + True if any temperature (T_i or T_e) is below min_temperature. + """ + if min_temperature <= 0.0: + return np.array(False) + return np.any( + np.array([ + np.any(np.less(self.T_i.value, min_temperature)), + np.any(np.less(self.T_e.value, min_temperature)), + ]) + ) + def __str__(self) -> str: return f""" CoreProfiles( @@ -292,6 +310,7 @@ class SimError(enum.Enum): QUASINEUTRALITY_BROKEN = 2 NEGATIVE_CORE_PROFILES = 3 REACHED_MIN_DT = 4 + BELOW_MIN_TEMPERATURE = 5 def log_error(self): match self: @@ -320,6 +339,13 @@ def log_error(self): quasineutrality. Check the output file for near-zero temperatures or densities at the last valid step. """) + case SimError.BELOW_MIN_TEMPERATURE: + logging.error(""" + Simulation stopped because temperature fell below the minimum + allowed threshold (min_temperature). This typically occurs during + radiative collapse scenarios. Check the output file for temperature + profiles at the last valid step. + """) case SimError.NO_ERROR: pass case _: diff --git a/torax/_src/tests/state_test.py b/torax/_src/tests/state_test.py index 179aea87a..2d60088a9 100644 --- a/torax/_src/tests/state_test.py +++ b/torax/_src/tests/state_test.py @@ -126,6 +126,50 @@ def test_core_profiles_negative_values_check(self): ) self.assertFalse(new_core_profiles.negative_temperature_or_density()) + def test_temperature_below_minimum(self): + """Tests the temperature_below_minimum method for issue #1607.""" + geo = circular_geometry.CircularConfig().build_geometry() + core_profiles = core_profile_helpers.make_zero_core_profiles(geo) + + # Set temperatures to 0.3 keV + core_profiles = dataclasses.replace( + core_profiles, + T_e=core_profile_helpers.make_constant_core_profile(geo, 0.3), + T_i=core_profile_helpers.make_constant_core_profile(geo, 0.3), + ) + + with self.subTest('min_temperature=0.0 disables check'): + # When min_temperature is 0.0, feature is disabled + self.assertFalse(core_profiles.temperature_below_minimum(0.0)) + + with self.subTest('min_temperature negative disables check'): + # When min_temperature is negative, feature is disabled + self.assertFalse(core_profiles.temperature_below_minimum(-1.0)) + + with self.subTest('temperature below threshold triggers'): + # T=0.3, min=0.5 should trigger + self.assertTrue(core_profiles.temperature_below_minimum(0.5)) + + with self.subTest('temperature above threshold does not trigger'): + # T=0.3, min=0.1 should not trigger + self.assertFalse(core_profiles.temperature_below_minimum(0.1)) + + with self.subTest('T_e below triggers even if T_i above'): + # T_e=0.3, T_i=1.0, min=0.5 should trigger (T_e is below) + mixed_profiles = dataclasses.replace( + core_profiles, + T_i=core_profile_helpers.make_constant_core_profile(geo, 1.0), + ) + self.assertTrue(mixed_profiles.temperature_below_minimum(0.5)) + + with self.subTest('T_i below triggers even if T_e above'): + # T_e=1.0, T_i=0.3, min=0.5 should trigger (T_i is below) + mixed_profiles = dataclasses.replace( + core_profiles, + T_e=core_profile_helpers.make_constant_core_profile(geo, 1.0), + ) + self.assertTrue(mixed_profiles.temperature_below_minimum(0.5)) + class ImpurityFractionsTest(parameterized.TestCase): """Tests for the impurity_fractions attribute in CoreProfiles."""