diff --git a/tableauserverclient/server/endpoint/flow_runs_endpoint.py b/tableauserverclient/server/endpoint/flow_runs_endpoint.py
index 3d09ad569..2c3bb84bc 100644
--- a/tableauserverclient/server/endpoint/flow_runs_endpoint.py
+++ b/tableauserverclient/server/endpoint/flow_runs_endpoint.py
@@ -1,9 +1,9 @@
 import logging
-from typing import Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING, Union
 
 from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api
 from tableauserverclient.server.endpoint.exceptions import FlowRunFailedException, FlowRunCancelledException
-from tableauserverclient.models import FlowRunItem, PaginationItem
+from tableauserverclient.models import FlowRunItem
 from tableauserverclient.exponential_backoff import ExponentialBackoffTimer
 
 from tableauserverclient.helpers.logging import logger
@@ -25,13 +25,15 @@ def baseurl(self) -> str:
 
     # Get all flows
    @api(version="3.10")
-    def get(self, req_options: Optional["RequestOptions"] = None) -> tuple[list[FlowRunItem], PaginationItem]:
+    # QuerysetEndpoint expects a PaginationItem to be returned, but FlowRuns
+    # does not return a PaginationItem. Suppressing the mypy error because the
+    # changes to the QuerySet class should permit this to function regardless.
+    def get(self, req_options: Optional["RequestOptions"] = None) -> list[FlowRunItem]:  # type: ignore[override]
         logger.info("Querying all flow runs on site")
         url = self.baseurl
         server_response = self.get_request(url, req_options)
-        pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
         all_flow_run_items = FlowRunItem.from_response(server_response.content, self.parent_srv.namespace)
-        return all_flow_run_items, pagination_item
+        return all_flow_run_items
 
     # Get 1 flow by id
     @api(version="3.10")
@@ -46,7 +48,7 @@ def get_by_id(self, flow_run_id: str) -> FlowRunItem:
 
     # Cancel 1 flow run by id
     @api(version="3.10")
-    def cancel(self, flow_run_id: str) -> None:
+    def cancel(self, flow_run_id: Union[str, FlowRunItem]) -> None:
         if not flow_run_id:
             error = "Flow ID undefined."
             raise ValueError(error)
diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py
index e72b29ab2..feebc1a7e 100644
--- a/tableauserverclient/server/query.py
+++ b/tableauserverclient/server/query.py
@@ -1,9 +1,10 @@
-from collections.abc import Sized
+from collections.abc import Iterable, Iterator, Sized
 from itertools import count
 from typing import Optional, Protocol, TYPE_CHECKING, TypeVar, overload
-from collections.abc import Iterable, Iterator
+import sys
 from tableauserverclient.config import config
 from tableauserverclient.models.pagination_item import PaginationItem
+from tableauserverclient.server.endpoint.exceptions import ServerResponseError
 from tableauserverclient.server.filter import Filter
 from tableauserverclient.server.request_options import RequestOptions
 from tableauserverclient.server.sort import Sort
@@ -35,6 +36,32 @@ def to_camel_case(word: str) -> str:
 
 
 class QuerySet(Iterable[T], Sized):
+    """
+    QuerySet is a class that allows easy filtering, sorting, and iterating over
+    many endpoints in TableauServerClient. It is designed to be used in a similar
+    way to Django QuerySets, but with a more limited feature set.
+
+    QuerySet is an iterable, and can be used in for loops, list comprehensions,
+    and other places where iterables are expected.
+
+    QuerySet is also Sized, and can be used in places where the length of the
+    QuerySet is needed. The length of the QuerySet is the total number of items
+    available in the QuerySet, not just the number of items that have been
+    fetched. If the endpoint does not return a total count of items, the length
+    of the QuerySet will be sys.maxsize. If there is no total count, the
+    QuerySet will continue to fetch items until there are no more items to
+    fetch.
+
+    QuerySet is not re-entrant. It is not designed to be used in multiple places
+    at the same time. If you need to use a QuerySet in multiple places, you
+    should create a new QuerySet for each place you need to use it, convert it
+    to a list, or create a deep copy of the QuerySet.
+
+    QuerySets are also indexable, and can be sliced. If you try to access an
+    index that has not been fetched, the QuerySet will fetch the page that
+    contains the item you are looking for.
+    """
+
     def __init__(self, model: "QuerysetEndpoint[T]", page_size: Optional[int] = None) -> None:
         self.model = model
         self.request_options = RequestOptions(pagesize=page_size or config.PAGE_SIZE)
@@ -50,10 +77,20 @@ def __iter__(self: Self) -> Iterator[T]:
         for page in count(1):
             self.request_options.pagenumber = page
             self._result_cache = []
-            self._fetch_all()
+            try:
+                self._fetch_all()
+            except ServerResponseError as e:
+                if e.code == "400006":
+                    # If the endpoint does not support pagination, the iteration
+                    # will end up overrunning the total number of pages. Catch
+                    # the error and stop iterating.
+                    return
             yield from self._result_cache
-            # Set result_cache to empty so the fetch will populate
-            if (page * self.page_size) >= len(self):
+            # If the total number of items is unknown, continue fetching until
+            # the server reports that there are no more pages.
+            if (size := len(self)) == 0:
+                continue
+            if (page * self.page_size) >= size:
                 return
 
     @overload
@@ -114,10 +151,15 @@ def _fetch_all(self: Self) -> None:
         Retrieve the data and store result and pagination item in cache
         """
         if not self._result_cache:
-            self._result_cache, self._pagination_item = self.model.get(self.request_options)
+            response = self.model.get(self.request_options)
+            if isinstance(response, tuple):
+                self._result_cache, self._pagination_item = response
+            else:
+                self._result_cache = response
+                self._pagination_item = PaginationItem()
 
     def __len__(self: Self) -> int:
-        return self.total_available
+        return self.total_available or sys.maxsize
 
     @property
     def total_available(self: Self) -> int:
@@ -127,12 +169,16 @@ def total_available(self: Self) -> int:
     @property
     def page_number(self: Self) -> int:
         self._fetch_all()
-        return self._pagination_item.page_number
+        # If the PaginationItem is not returned from the endpoint, use the
+        # pagenumber from the RequestOptions.
+        return self._pagination_item.page_number or self.request_options.pagenumber
 
     @property
     def page_size(self: Self) -> int:
         self._fetch_all()
-        return self._pagination_item.page_size
+        # If the PaginationItem is not returned from the endpoint, use the
+        # pagesize from the RequestOptions.
+        return self._pagination_item.page_size or self.request_options.pagesize
 
     def filter(self: Self, *invalid, page_size: Optional[int] = None, **kwargs) -> Self:
         if invalid:
diff --git a/test/_utils.py b/test/_utils.py
index 8527aaf8c..b4ee93bc3 100644
--- a/test/_utils.py
+++ b/test/_utils.py
@@ -1,5 +1,6 @@
 import os.path
 import unittest
+from xml.etree import ElementTree as ET
 from contextlib import contextmanager
 
 TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
@@ -18,6 +19,19 @@ def read_xml_assets(*args):
     return map(read_xml_asset, args)
 
 
+def server_response_error_factory(code: str, summary: str, detail: str) -> str:
+    root = ET.Element("tsResponse")
+    error = ET.SubElement(root, "error")
+    error.attrib["code"] = code
+
+    summary_element = ET.SubElement(error, "summary")
+    summary_element.text = summary
+
+    detail_element = ET.SubElement(error, "detail")
+    detail_element.text = detail
+    return ET.tostring(root, encoding="utf-8").decode("utf-8")
+
+
 @contextmanager
 def mocked_time():
     mock_time = 0
diff --git a/test/assets/flow_runs_get.xml b/test/assets/flow_runs_get.xml
index bdce4cdfb..489e8ac63 100644
--- a/test/assets/flow_runs_get.xml
+++ b/test/assets/flow_runs_get.xml
@@ -1,5 +1,4 @@
-
-
\ No newline at end of file
+
diff --git a/test/test_flowruns.py b/test/test_flowruns.py
index e1ddd5541..8af2540dc 100644
--- a/test/test_flowruns.py
+++ b/test/test_flowruns.py
@@ -1,3 +1,4 @@
+import sys
 import unittest
 
 import requests_mock
@@ -5,7 +6,7 @@
 import tableauserverclient as TSC
 from tableauserverclient.datetime_helpers import format_datetime
 from tableauserverclient.server.endpoint.exceptions import FlowRunFailedException
-from ._utils import read_xml_asset, mocked_time
+from ._utils import read_xml_asset, mocked_time, server_response_error_factory
 
 GET_XML = "flow_runs_get.xml"
 GET_BY_ID_XML = "flow_runs_get_by_id.xml"
@@ -28,9 +29,8 @@ def test_get(self) -> None:
         response_xml = read_xml_asset(GET_XML)
         with requests_mock.mock() as m:
             m.get(self.baseurl, text=response_xml)
-            all_flow_runs, pagination_item = self.server.flow_runs.get()
+            all_flow_runs = self.server.flow_runs.get()
 
-        self.assertEqual(2, pagination_item.total_available)
         self.assertEqual("cc2e652d-4a9b-4476-8c93-b238c45db968", all_flow_runs[0].id)
         self.assertEqual("2021-02-11T01:42:55Z", format_datetime(all_flow_runs[0].started_at))
         self.assertEqual("2021-02-11T01:57:38Z", format_datetime(all_flow_runs[0].completed_at))
@@ -98,3 +98,14 @@ def test_wait_for_job_timeout(self) -> None:
             m.get(f"{self.baseurl}/{flow_run_id}", text=response_xml)
             with self.assertRaises(TimeoutError):
                 self.server.flow_runs.wait_for_job(flow_run_id, timeout=30)
+
+    def test_queryset(self) -> None:
+        response_xml = read_xml_asset(GET_XML)
+        error_response = server_response_error_factory(
+            "400006", "Bad Request", "0xB4EAB088 : The start index '9900' is greater than or equal to the total count.)"
+        )
+        with requests_mock.mock() as m:
+            m.get(f"{self.baseurl}?pageNumber=1", text=response_xml)
+            m.get(f"{self.baseurl}?pageNumber=2", text=error_response)
+            queryset = self.server.flow_runs.all()
+            assert len(queryset) == sys.maxsize
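
Below is a minimal usage sketch of the behavior this patch enables. It is illustrative only and not part of the patch: the server URL, token name/value, and site name are placeholders, and the length/iteration behavior shown assumes the flow runs endpoint returns no pagination element, as the changes above expect.

    import sys

    import tableauserverclient as TSC

    # Placeholder connection details; swap in real values for your site.
    server = TSC.Server("https://tableau.example.com", use_server_version=True)
    auth = TSC.PersonalAccessTokenAuth("token-name", "token-value", site_id="example-site")

    with server.auth.sign_in(auth):
        flow_runs = server.flow_runs.all()

        # The flow runs endpoint does not report a total count, so len() falls
        # back to sys.maxsize instead of a real total (see QuerySet.__len__).
        print(len(flow_runs) == sys.maxsize)

        # Iteration keeps requesting pages until the server answers with error
        # code 400006, at which point the QuerySet stops quietly.
        for run in flow_runs:
            print(run.id, run.status)

This mirrors what test_queryset exercises with requests_mock: page one returns flow runs, and the 400006 error response marks the end of iteration.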