Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions tests/test_tasks/test_task_methods.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# License: BSD 3-Clause
from __future__ import annotations

from time import time
from unittest import mock

import pytest
import requests

import openml
import openml._api_calls
from openml.testing import TestBase
import pytest


# Common methods between tasks
Expand All @@ -16,23 +19,49 @@ def setUp(self):
def tearDown(self):
super().tearDown()

@pytest.mark.test_server()
def test_tagging(self):
task = openml.tasks.get_task(1) # anneal; crossvalidation
# tags can be at most 64 alphanumeric (+ underscore) chars
unique_indicator = str(time()).replace(".", "")
tag = f"test_tag_OpenMLTaskMethodsTest_{unique_indicator}"
tasks = openml.tasks.list_tasks(tag=tag)
assert len(tasks) == 0
task.push_tag(tag)
tasks = openml.tasks.list_tasks(tag=tag)
assert len(tasks) == 1
assert 1 in tasks["tid"]
task.remove_tag(tag)
tasks = openml.tasks.list_tasks(tag=tag)
assert len(tasks) == 0

@pytest.mark.test_server()
openml.config.set_root_cache_directory(self.static_cache_dir)
task = openml.tasks.get_task(1882, download_data=False)

tag = "test_tag_OpenMLTaskMethodsTest"
tag_url = openml._api_calls._create_url_from_endpoint("task/tag")
untag_url = openml._api_calls._create_url_from_endpoint("task/untag")
expected_data = {"task_id": 1882, "tag": tag, "api_key": openml.config.apikey}

def _make_response(content: str) -> requests.Response:
response = requests.Response()
response.status_code = 200
response._content = content.encode()
return response

with mock.patch.object(
requests.Session,
"post",
return_value=_make_response(
f'<oml:task_tag xmlns:oml="http://openml.org/openml">'
f"<oml:id>1882</oml:id><oml:tag>{tag}</oml:tag>"
f"</oml:task_tag>"
),
) as mock_post:
task.push_tag(tag)
mock_post.assert_called_once_with(
tag_url, data=expected_data, files=None, headers=openml._api_calls._HEADERS
)

with mock.patch.object(
requests.Session,
"post",
return_value=_make_response(
'<oml:task_untag xmlns:oml="http://openml.org/openml">'
"<oml:id>1882</oml:id>"
"</oml:task_untag>"
),
) as mock_post:
task.remove_tag(tag)
mock_post.assert_called_once_with(
untag_url, data=expected_data, files=None, headers=openml._api_calls._HEADERS
)

def test_get_train_and_test_split_indices(self):
openml.config.set_root_cache_directory(self.static_cache_dir)
task = openml.tasks.get_task(1882)
Expand All @@ -46,17 +75,7 @@ def test_get_train_and_test_split_indices(self):
assert train_indices[-1] == 681
assert test_indices[0] == 583
assert test_indices[-1] == 24
self.assertRaisesRegex(
ValueError,
"Fold 10 not known",
task.get_train_test_split_indices,
10,
0,
)
self.assertRaisesRegex(
ValueError,
"Repeat 10 not known",
task.get_train_test_split_indices,
0,
10,
)
with pytest.raises(ValueError, match="Fold 10 not known"):
task.get_train_test_split_indices(10, 0)
with pytest.raises(ValueError, match="Repeat 10 not known"):
task.get_train_test_split_indices(0, 10)
Loading