diff --git a/tests/test_extensions/test_functions.py b/tests/test_extensions/test_functions.py index ac4610a15..90fbaa9f1 100644 --- a/tests/test_extensions/test_functions.py +++ b/tests/test_extensions/test_functions.py @@ -1,12 +1,14 @@ # License: BSD 3-Clause from __future__ import annotations -import inspect +from collections import OrderedDict +import inspect +import numpy as np import pytest - +from unittest.mock import patch import openml.testing -from openml.extensions import get_extension_by_flow, get_extension_by_model, register_extension +from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension class DummyFlow: @@ -40,54 +42,197 @@ def can_handle_model(model): return False -def _unregister(): - # "Un-register" the test extensions - while True: - rem_dum_ext1 = False - rem_dum_ext2 = False - try: - openml.extensions.extensions.remove(DummyExtension1) - rem_dum_ext1 = True - except ValueError: - pass - try: - openml.extensions.extensions.remove(DummyExtension2) - rem_dum_ext2 = True - except ValueError: - pass - if not rem_dum_ext1 and not rem_dum_ext2: - break +class DummyExtension(Extension): + @classmethod + def can_handle_flow(cls, flow): + return isinstance(flow, DummyFlow) + + @classmethod + def can_handle_model(cls, model): + return isinstance(model, DummyModel) + + def flow_to_model( + self, + flow, + initialize_with_defaults=False, + strict_version=True, + ): + if not isinstance(flow, DummyFlow): + raise ValueError("Invalid flow") + + model = DummyModel() + model.defaults = initialize_with_defaults + model.strict_version = strict_version + return model + + def model_to_flow(self, model): + if not isinstance(model, DummyModel): + raise ValueError("Invalid model") + return DummyFlow() + + def get_version_information(self): + return ["dummy==1.0"] + + def create_setup_string(self, model): + return "DummyModel()" + + def is_estimator(self, model): + return isinstance(model, DummyModel) + + def seed_model(self, model, seed): + model.seed = seed + return model + + def _run_model_on_fold( + self, + model, + task, + X_train, + rep_no, + fold_no, + y_train=None, + X_test=None, + ): + preds = np.zeros(len(X_train)) + probs = None + measures = OrderedDict() + trace = None + return preds, probs, measures, trace + + def obtain_parameter_values(self, flow, model=None): + return [] + + def check_if_model_fitted(self, model): + return False + + def instantiate_model_from_hpo_class(self, model, trace_iteration): + return DummyModel() + class TestInit(openml.testing.TestBase): - def setUp(self): - super().setUp() - _unregister() def test_get_extension_by_flow(self): - assert get_extension_by_flow(DummyFlow()) is None - with pytest.raises(ValueError, match="No extension registered which can handle flow:"): - get_extension_by_flow(DummyFlow(), raise_if_no_extension=True) - register_extension(DummyExtension1) - assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) - register_extension(DummyExtension2) - assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) - register_extension(DummyExtension1) - with pytest.raises( - ValueError, match="Multiple extensions registered which can handle flow:" - ): - get_extension_by_flow(DummyFlow()) + # We replace the global list with a new empty list [] ONLY for this block + with patch("openml.extensions.extensions", []): + assert get_extension_by_flow(DummyFlow()) is None + + with pytest.raises(ValueError, match="No extension registered which can handle flow:"): + get_extension_by_flow(DummyFlow(), raise_if_no_extension=True) + + register_extension(DummyExtension1) + assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) + + register_extension(DummyExtension2) + assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) + + register_extension(DummyExtension1) + with pytest.raises( + ValueError, match="Multiple extensions registered which can handle flow:" + ): + get_extension_by_flow(DummyFlow()) def test_get_extension_by_model(self): - assert get_extension_by_model(DummyModel()) is None - with pytest.raises(ValueError, match="No extension registered which can handle model:"): - get_extension_by_model(DummyModel(), raise_if_no_extension=True) - register_extension(DummyExtension1) - assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) - register_extension(DummyExtension2) - assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) - register_extension(DummyExtension1) - with pytest.raises( - ValueError, match="Multiple extensions registered which can handle model:" - ): - get_extension_by_model(DummyModel()) + # Again, we start with a fresh empty list automatically + with patch("openml.extensions.extensions", []): + assert get_extension_by_model(DummyModel()) is None + + with pytest.raises(ValueError, match="No extension registered which can handle model:"): + get_extension_by_model(DummyModel(), raise_if_no_extension=True) + + register_extension(DummyExtension1) + assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) + + register_extension(DummyExtension2) + assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) + + register_extension(DummyExtension1) + with pytest.raises( + ValueError, match="Multiple extensions registered which can handle model:" + ): + get_extension_by_model(DummyModel()) + + +def test_flow_to_model_with_defaults(): + """Test flow_to_model with initialize_with_defaults=True.""" + ext = DummyExtension() + flow = DummyFlow() + + model = ext.flow_to_model(flow, initialize_with_defaults=True) + + assert isinstance(model, DummyModel) + assert model.defaults is True + +def test_flow_to_model_strict_version(): + """Test flow_to_model with strict_version parameter.""" + ext = DummyExtension() + flow = DummyFlow() + + model_strict = ext.flow_to_model(flow, strict_version=True) + model_non_strict = ext.flow_to_model(flow, strict_version=False) + + assert isinstance(model_strict, DummyModel) + assert model_strict.strict_version is True + + assert isinstance(model_non_strict, DummyModel) + assert model_non_strict.strict_version is False + +def test_model_to_flow_conversion(): + """Test converting a model back to flow representation.""" + ext = DummyExtension() + model = DummyModel() + + flow = ext.model_to_flow(model) + + assert isinstance(flow, DummyFlow) + + +def test_invalid_flow_raises_error(): + """Test that invalid flow raises appropriate error.""" + class InvalidFlow: + pass + + ext = DummyExtension() + flow = InvalidFlow() + + with pytest.raises(ValueError, match="Invalid flow"): + ext.flow_to_model(flow) + + +@patch("openml.extensions.extensions", []) +def test_extension_not_found_error_message(): + """Test error message contains helpful information.""" + class UnknownModel: + pass + + with pytest.raises(ValueError, match="No extension registered"): + get_extension_by_model(UnknownModel(), raise_if_no_extension=True) + + +def test_register_same_extension_twice(): + """Test behavior when registering same extension twice.""" + # Using a context manager here to isolate the list + with patch("openml.extensions.extensions", []): + register_extension(DummyExtension) + register_extension(DummyExtension) + + matches = [ + ext for ext in openml.extensions.extensions + if ext is DummyExtension + ] + assert len(matches) == 2 + + +@patch("openml.extensions.extensions", []) +def test_extension_priority_order(): + """Test that extensions are checked in registration order.""" + class DummyExtensionA(DummyExtension): + pass + class DummyExtensionB(DummyExtension): + pass + + register_extension(DummyExtensionA) + register_extension(DummyExtensionB) + + assert openml.extensions.extensions[0] is DummyExtensionA + assert openml.extensions.extensions[1] is DummyExtensionB \ No newline at end of file