diff --git a/tests/files/mock_responses/setups/setup_list_flow5873.xml b/tests/files/mock_responses/setups/setup_list_flow5873.xml new file mode 100644 index 000000000..509ea7e3e --- /dev/null +++ b/tests/files/mock_responses/setups/setup_list_flow5873.xml @@ -0,0 +1,10 @@ + + + 1001 + 5873 + + + 1002 + 5873 + + diff --git a/tests/test_setups/test_setup_functions.py b/tests/test_setups/test_setup_functions.py index 0df3a0b3b..27971e12e 100644 --- a/tests/test_setups/test_setup_functions.py +++ b/tests/test_setups/test_setup_functions.py @@ -3,10 +3,11 @@ import hashlib import time -import unittest.mock +import unittest.mock as mock import pandas as pd import pytest +import requests import sklearn.base import sklearn.naive_bayes import sklearn.tree @@ -14,7 +15,7 @@ import openml import openml.exceptions -from openml.testing import TestBase +from openml.testing import TestBase, create_request_response def get_sentinel(): @@ -135,18 +136,6 @@ def test_get_setup(self): else: assert len(current.parameters) == num_params[idx] - @pytest.mark.production_server() - def test_setup_list_filter_flow(self): - self.use_production_server() - - flow_id = 5873 - - setups = openml.setups.list_setups(flow=flow_id) - - assert len(setups) > 0 # TODO: please adjust 0 - for setup_id in setups: - assert setups[setup_id].flow_id == flow_id - @pytest.mark.test_server() def test_list_setups_empty(self): setups = openml.setups.list_setups(setup=[0]) @@ -189,3 +178,26 @@ def test_get_uncached_setup(self): openml.config.set_root_cache_directory(self.static_cache_dir) with pytest.raises(openml.exceptions.OpenMLCacheException): openml.setups.functions._get_cached_setup(10) + + +@mock.patch.object(requests.Session, "get") +def test_setup_list_filter_flow(mock_get, test_files_directory, test_api_key): + content_file = ( + test_files_directory / "mock_responses" / "setups" / "setup_list_flow5873.xml" + ) + mock_get.return_value = create_request_response( + status_code=200, + content_filepath=content_file, + ) + + flow_id = 5873 + setups = openml.setups.list_setups(flow=flow_id) + + assert len(setups) > 0 + for setup_id in setups: + assert setups[setup_id].flow_id == flow_id + + # Verify the GET URL contains the flow filter + call_url = mock_get.call_args.args[0] + assert "setup/list" in call_url + assert f"flow/{flow_id}" in call_url