diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index dd939620..ac9a2e75 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -3,7 +3,7 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} USER vscode -RUN curl -sSf https://rye-up.com/get | RYE_VERSION="0.24.0" RYE_INSTALL_OPTION="--yes" bash +RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.35.0" RYE_INSTALL_OPTION="--yes" bash ENV PATH=/home/vscode/.rye/shims:$PATH RUN echo "[[ -d .venv ]] && source .venv/bin/activate" >> /home/vscode/.bashrc diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6ce2f72b..40293964 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,7 @@ on: pull_request: branches: - main + - next jobs: lint: @@ -18,27 +19,17 @@ jobs: - name: Install Rye run: | - curl -sSf https://rye-up.com/get | bash + curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: 0.24.0 + RYE_VERSION: '0.35.0' RYE_INSTALL_OPTION: '--yes' - name: Install dependencies - run: | - rye sync --all-features - - - name: Run ruff - run: | - rye run check:ruff + run: rye sync --all-features - - name: Run type checking - run: | - rye run typecheck - - - name: Ensure importable - run: | - rye run python -c 'import maisa' + - name: Run lints + run: ./scripts/lint test: name: test runs-on: ubuntu-latest @@ -48,10 +39,10 @@ jobs: - name: Install Rye run: | - curl -sSf https://rye-up.com/get | bash + curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: 0.24.0 + RYE_VERSION: '0.35.0' RYE_INSTALL_OPTION: '--yes' - name: Bootstrap diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 3a18e4be..a251ab22 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -18,11 +18,11 @@ jobs: - name: Install Rye run: | - curl -sSf https://rye-up.com/get | bash + curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: 0.24.0 - RYE_INSTALL_OPTION: "--yes" + RYE_VERSION: '0.35.0' + RYE_INSTALL_OPTION: '--yes' - name: Publish to PyPI run: | diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml index 588669b5..e2332a84 100644 --- a/.github/workflows/release-doctor.yml +++ b/.github/workflows/release-doctor.yml @@ -1,6 +1,8 @@ name: Release Doctor on: pull_request: + branches: + - main workflow_dispatch: jobs: diff --git a/.gitignore b/.gitignore index 0f9a66a9..87797408 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.prism.log .vscode _dev diff --git a/.stats.yml b/.stats.yml index 9128fdb4..4076c464 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,2 +1,2 @@ configured_endpoints: 13 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/maisa%2Fmaisa-c82e0a80b379d33d11af783865ec15844bcd750043bf0f1543ebdf72bda79f3a.yml +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/maisa%2FMaisa-c82e0a80b379d33d11af783865ec15844bcd750043bf0f1543ebdf72bda79f3a.yml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cf8d4723..a2ddcfb0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,9 +2,13 @@ ### With Rye -We use [Rye](https://rye-up.com/) to manage dependencies so we highly recommend [installing it](https://rye-up.com/guide/installation/) as it will automatically provision a Python environment with the expected Python version. +We use [Rye](https://rye.astral.sh/) to manage dependencies because it will automatically provision a Python environment with the expected Python version. To set it up, run: -After installing Rye, you'll just have to run this command: +```sh +$ ./scripts/bootstrap +``` + +Or [install Rye manually](https://rye.astral.sh/guide/installation/) and run: ```sh $ rye sync --all-features @@ -31,25 +35,25 @@ $ pip install -r requirements-dev.lock ## Modifying/Adding code -Most of the SDK is generated code, and any modified code will be overridden on the next generation. The -`src/maisa/lib/` and `examples/` directories are exceptions and will never be overridden. +Most of the SDK is generated code. Modifications to code will be persisted between generations, but may +result in merge conflicts between manual patches and changes from the generator. The generator will never +modify the contents of the `src/maisa/lib/` and `examples/` directories. ## Adding and running examples -All files in the `examples/` directory are not modified by the Stainless generator and can be freely edited or -added to. +All files in the `examples/` directory are not modified by the generator and can be freely edited or added to. -```bash +```py # add an example to examples/.py #!/usr/bin/env -S rye run python … ``` -``` -chmod +x examples/.py +```sh +$ chmod +x examples/.py # run the example against your api -./examples/.py +$ ./examples/.py ``` ## Using the repository from source @@ -58,8 +62,8 @@ If you’d like to use the repository from source, you can either install from g To install via git: -```bash -pip install git+ssh://git@github.com/maisaai/python-sdk.git +```sh +$ pip install git+ssh://git@github.com/maisaai/python-sdk.git ``` Alternatively, you can build from source and install the wheel file: @@ -68,29 +72,29 @@ Building this package will create two files in the `dist/` directory, a `.tar.gz To create a distributable version of the library, all you have to do is run this command: -```bash -rye build +```sh +$ rye build # or -python -m build +$ python -m build ``` Then to install: ```sh -pip install ./path-to-wheel-file.whl +$ pip install ./path-to-wheel-file.whl ``` ## Running tests Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. -```bash +```sh # you will need npm installed -npx prism mock path/to/your/openapi.yml +$ npx prism mock path/to/your/openapi.yml ``` -```bash -rye run pytest +```sh +$ ./scripts/test ``` ## Linting and formatting @@ -100,14 +104,14 @@ This repository uses [ruff](https://github.com/astral-sh/ruff) and To lint: -```bash -rye run lint +```sh +$ ./scripts/lint ``` To format and fix all ruff issues automatically: -```bash -rye run format +```sh +$ ./scripts/format ``` ## Publishing and releases diff --git a/README.md b/README.md index 5d5c0a00..c82f836f 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![PyPI version](https://img.shields.io/pypi/v/maisa.svg)](https://pypi.org/project/maisa/) -The Maisa Python library provides convenient access to the Maisa REST API from any Python 3.7+ +The Maisa Python library provides convenient access to the Maisa REST API from any Python 3.8+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -10,7 +10,7 @@ It is generated with [Stainless](https://www.stainlessapi.com/). ## Documentation -The REST API documentation can be found [on docs.maisa.ai](https://docs.maisa.ai/). The full API of this library can be found in [api.md](api.md). +The REST API documentation can be found on [docs.maisa.ai](https://docs.maisa.ai/). The full API of this library can be found in [api.md](api.md). ## Installation @@ -270,7 +270,7 @@ You can directly override the [httpx client](https://www.python-httpx.org/api/#c - Support for proxies - Custom transports -- Additional [advanced](https://www.python-httpx.org/advanced/#client-instances) functionality +- Additional [advanced](https://www.python-httpx.org/advanced/clients/) functionality ```python from maisa import Maisa, DefaultHttpxClient @@ -285,6 +285,12 @@ client = Maisa( ) ``` +You can also customize the client on a per-request basis by using `with_options()`: + +```python +client.with_options(http_client=DefaultHttpxClient(...)) +``` + ### Managing HTTP resources By default the library closes underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__). You can manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting. @@ -301,6 +307,21 @@ We take backwards-compatibility seriously and work hard to ensure you can rely o We are keen for your feedback; please open an [issue](https://www.github.com/maisaai/python-sdk/issues) with questions, bugs, or suggestions. +### Determining the installed version + +If you've upgraded to the latest version but aren't seeing any new features you were expecting then your python environment is likely still using an older version. + +You can determine the version that is being used at runtime with: + +```py +import maisa +print(maisa.__version__) +``` + ## Requirements -Python 3.7 or higher. +Python 3.8 or higher. + +## Contributing + +See [the contributing documentation](./CONTRIBUTING.md). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..459e7314 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,27 @@ +# Security Policy + +## Reporting Security Issues + +This SDK is generated by [Stainless Software Inc](http://stainlessapi.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. + +To report a security issue, please contact the Stainless team at security@stainlessapi.com. + +## Responsible Disclosure + +We appreciate the efforts of security researchers and individuals who help us maintain the security of +SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible +disclosure practices by allowing us a reasonable amount of time to investigate and address the issue +before making any information public. + +## Reporting Non-SDK Related Security Issues + +If you encounter security issues that are not directly related to SDKs but pertain to the services +or products provided by Maisa please follow the respective company's security reporting guidelines. + +### Maisa Terms and Policies + +Please contact support@maisa.ai for any questions or concerns regarding security of our services. + +--- + +Thank you for helping us keep the SDKs and systems they interact with secure. diff --git a/bin/check-release-environment b/bin/check-release-environment index dff35355..5af9c3e6 100644 --- a/bin/check-release-environment +++ b/bin/check-release-environment @@ -1,20 +1,9 @@ #!/usr/bin/env bash -warnings=() errors=() if [ -z "${PYPI_TOKEN}" ]; then - warnings+=("The MAISA_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") -fi - -lenWarnings=${#warnings[@]} - -if [[ lenWarnings -gt 0 ]]; then - echo -e "Found the following warnings in the release environment:\n" - - for warning in "${warnings[@]}"; do - echo -e "- $warning\n" - done + errors+=("The MAISA_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") fi lenErrors=${#errors[@]} diff --git a/bin/publish-pypi b/bin/publish-pypi index 826054e9..05bfccbb 100644 --- a/bin/publish-pypi +++ b/bin/publish-pypi @@ -3,4 +3,7 @@ set -eux mkdir -p dist rye build --clean +# Patching importlib-metadata version until upstream library version is updated +# https://github.com/pypa/twine/issues/977#issuecomment-2189800841 +"$HOME/.rye/self/bin/python3" -m pip install 'importlib-metadata==7.2.1' rye publish --yes --token=$PYPI_TOKEN diff --git a/pyproject.toml b/pyproject.toml index d9df4e66..1720adab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,11 @@ dependencies = [ "distro>=1.7.0, <2", "sniffio", "cached-property; python_version < '3.8'", - ] -requires-python = ">= 3.7" +requires-python = ">= 3.8" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -36,8 +34,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License" ] - - [project.urls] Homepage = "https://github.com/maisaai/python-sdk" Repository = "https://github.com/maisaai/python-sdk" @@ -58,7 +54,7 @@ dev-dependencies = [ "nox", "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", - + "rich>=13.7.1", ] [tool.rye.scripts] @@ -66,18 +62,21 @@ format = { chain = [ "format:ruff", "format:docs", "fix:ruff", + # run formatting again to fix any inconsistencies when imports are stripped + "format:ruff", ]} -"format:black" = "black ." "format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md" "format:ruff" = "ruff format" -"format:isort" = "isort ." "lint" = { chain = [ "check:ruff", "typecheck", + "check:importable", ]} -"check:ruff" = "ruff ." -"fix:ruff" = "ruff --fix ." +"check:ruff" = "ruff check ." +"fix:ruff" = "ruff check --fix ." + +"check:importable" = "python -c 'import maisa'" typecheck = { chain = [ "typecheck:pyright", @@ -99,6 +98,21 @@ include = [ [tool.hatch.build.targets.wheel] packages = ["src/maisa"] +[tool.hatch.build.targets.sdist] +# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc) +include = [ + "/*.toml", + "/*.json", + "/*.lock", + "/*.md", + "/mypy.ini", + "/noxfile.py", + "bin/*", + "examples/*", + "src/*", + "tests/*", +] + [tool.hatch.metadata.hooks.fancy-pypi-readme] content-type = "text/markdown" @@ -110,10 +124,6 @@ path = "README.md" pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' replacement = '[\1](https://github.com/maisaai/python-sdk/tree/main/\g<2>)' -[tool.black] -line-length = 120 -target-version = ["py37"] - [tool.pytest.ini_options] testpaths = ["tests"] addopts = "--tb=short" @@ -128,7 +138,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.7" +pythonVersion = "3.8" exclude = [ "_dev", @@ -146,6 +156,11 @@ reportPrivateUsage = false line-length = 120 output-format = "grouped" target-version = "py37" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] select = [ # isort "I", @@ -174,10 +189,6 @@ unfixable = [ "T201", "T203", ] -ignore-init-module-imports = true - -[tool.ruff.format] -docstring-code-format = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" @@ -189,7 +200,7 @@ combine-as-imports = true extra-standard-library = ["typing_extensions"] known-first-party = ["maisa", "tests"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] diff --git a/requirements-dev.lock b/requirements-dev.lock index 7a6ac29b..aab8a26d 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,17 +6,16 @@ # features: [] # all-features: true # with-sources: false +# generate-hashes: false -e file:. annotated-types==0.6.0 # via pydantic -anyio==4.1.0 +anyio==4.4.0 # via httpx # via maisa argcomplete==3.1.2 # via nox -attrs==23.1.0 - # via pytest certifi==2023.7.22 # via httpcore # via httpx @@ -27,8 +26,9 @@ distlib==0.3.7 # via virtualenv distro==1.8.0 # via maisa -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio + # via pytest filelock==3.12.4 # via virtualenv h11==0.14.0 @@ -44,7 +44,11 @@ idna==3.4 importlib-metadata==7.0.0 iniconfig==2.0.0 # via pytest -mypy==1.7.1 +markdown-it-py==3.0.0 + # via rich +mdurl==0.1.2 + # via markdown-it-py +mypy==1.13.0 mypy-extensions==1.0.0 # via mypy nodeenv==1.8.0 @@ -55,24 +59,25 @@ packaging==23.2 # via pytest platformdirs==3.11.0 # via virtualenv -pluggy==1.3.0 +pluggy==1.5.0 # via pytest -py==1.11.0 - # via pytest -pydantic==2.4.2 +pydantic==2.9.2 # via maisa -pydantic-core==2.10.1 +pydantic-core==2.23.4 # via pydantic -pyright==1.1.359 -pytest==7.1.1 +pygments==2.18.0 + # via rich +pyright==1.1.380 +pytest==8.3.3 # via pytest-asyncio -pytest-asyncio==0.21.1 +pytest-asyncio==0.24.0 python-dateutil==2.8.2 # via time-machine pytz==2023.3.post1 # via dirty-equals respx==0.20.2 -ruff==0.1.9 +rich==13.7.1 +ruff==0.6.9 setuptools==68.2.2 # via nodeenv six==1.16.0 @@ -82,10 +87,11 @@ sniffio==1.3.0 # via httpx # via maisa time-machine==2.9.0 -tomli==2.0.1 +tomli==2.0.2 # via mypy # via pytest -typing-extensions==4.8.0 +typing-extensions==4.12.2 + # via anyio # via maisa # via mypy # via pydantic diff --git a/requirements.lock b/requirements.lock index 6a42654c..46d20886 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,11 +6,12 @@ # features: [] # all-features: true # with-sources: false +# generate-hashes: false -e file:. annotated-types==0.6.0 # via pydantic -anyio==4.1.0 +anyio==4.4.0 # via httpx # via maisa certifi==2023.7.22 @@ -18,7 +19,7 @@ certifi==2023.7.22 # via httpx distro==1.8.0 # via maisa -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio h11==0.14.0 # via httpcore @@ -29,15 +30,16 @@ httpx==0.25.2 idna==3.4 # via anyio # via httpx -pydantic==2.4.2 +pydantic==2.9.2 # via maisa -pydantic-core==2.10.1 +pydantic-core==2.23.4 # via pydantic sniffio==1.3.0 # via anyio # via httpx # via maisa -typing-extensions==4.8.0 +typing-extensions==4.12.2 + # via anyio # via maisa # via pydantic # via pydantic-core diff --git a/scripts/bootstrap b/scripts/bootstrap index 29df07e7..8c5c60eb 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -16,4 +16,4 @@ echo "==> Installing Python dependencies…" # experimental uv support makes installations significantly faster rye config --set-bool behavior.use-uv=true -rye sync +rye sync --all-features diff --git a/scripts/format b/scripts/format index 2a9ea466..667ec2d7 100755 --- a/scripts/format +++ b/scripts/format @@ -4,5 +4,5 @@ set -e cd "$(dirname "$0")/.." +echo "==> Running formatters" rye run format - diff --git a/scripts/lint b/scripts/lint index 0cc68b51..b9127eda 100755 --- a/scripts/lint +++ b/scripts/lint @@ -4,5 +4,9 @@ set -e cd "$(dirname "$0")/.." +echo "==> Running lints" rye run lint +echo "==> Making sure it imports" +rye run python -c 'import maisa' + diff --git a/scripts/mock b/scripts/mock index fe89a1d0..d2814ae6 100755 --- a/scripts/mock +++ b/scripts/mock @@ -21,7 +21,7 @@ echo "==> Starting mock server with URL ${URL}" # Run prism mock on the given spec if [ "$1" == "--daemon" ]; then - npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" &> .prism.log & + npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log & # Wait for server to come online echo -n "Waiting for server" @@ -37,5 +37,5 @@ if [ "$1" == "--daemon" ]; then echo else - npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" + npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" fi diff --git a/scripts/test b/scripts/test index be01d044..4fa5698b 100755 --- a/scripts/test +++ b/scripts/test @@ -52,6 +52,8 @@ else echo fi -# Run tests echo "==> Running tests" rye run pytest "$@" + +echo "==> Running Pydantic v1 tests" +rye run nox -s test-pydantic-v1 -- "$@" diff --git a/src/maisa/_base_client.py b/src/maisa/_base_client.py index 338b5a46..c5a8b7d2 100644 --- a/src/maisa/_base_client.py +++ b/src/maisa/_base_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import json import time import uuid @@ -58,9 +59,10 @@ HttpxSendArgs, AsyncTransport, RequestOptions, + HttpxRequestFiles, ModelBuilderProtocol, ) -from ._utils import is_dict, is_list, is_given, lru_cache, is_mapping +from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping from ._compat import model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( @@ -123,16 +125,14 @@ def __init__( self, *, url: URL, - ) -> None: - ... + ) -> None: ... @overload def __init__( self, *, params: Query, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -143,6 +143,12 @@ def __init__( self.url = url self.params = params + @override + def __repr__(self) -> str: + if self.url: + return f"{self.__class__.__name__}(url={self.url})" + return f"{self.__class__.__name__}(params={self.params})" + class BasePage(GenericModel, Generic[_T]): """ @@ -165,8 +171,7 @@ def has_next_page(self) -> bool: return False return self.next_page_info() is not None - def next_page_info(self) -> Optional[PageInfo]: - ... + def next_page_info(self) -> Optional[PageInfo]: ... def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] ... @@ -358,6 +363,7 @@ def __init__( self._custom_query = custom_query or {} self._strict_response_validation = _strict_response_validation self._idempotency_header = None + self._platform: Platform | None = None if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] raise TypeError( @@ -400,14 +406,7 @@ def _make_status_error( ) -> _exceptions.APIStatusError: raise NotImplementedError() - def _remaining_retries( - self, - remaining_retries: Optional[int], - options: FinalRequestOptions, - ) -> int: - return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) - - def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: + def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers: custom_headers = options.headers or {} headers_dict = _merge_mappings(self.default_headers, custom_headers) self._validate_headers(headers_dict, custom_headers) @@ -419,6 +418,11 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key or self._idempotency_key() + # Don't set the retry count header if it was already set or removed by the caller. We check + # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. + if "x-stainless-retry-count" not in (header.lower() for header in custom_headers): + headers["x-stainless-retry-count"] = str(retries_taken) + return headers def _prepare_url(self, url: str) -> URL: @@ -440,6 +444,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: def _build_request( self, options: FinalRequestOptions, + *, + retries_taken: int = 0, ) -> httpx.Request: if log.isEnabledFor(logging.DEBUG): log.debug("Request options: %s", model_dump(options, exclude_unset=True)) @@ -455,9 +461,10 @@ def _build_request( else: raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") - headers = self._build_headers(options) - params = _merge_mappings(self._custom_query, options.params) + headers = self._build_headers(options, retries_taken=retries_taken) + params = _merge_mappings(self.default_query, options.params) content_type = headers.get("Content-Type") + files = options.files # If the given Content-Type header is multipart/form-data then it # has to be removed so that httpx can generate the header with @@ -471,7 +478,7 @@ def _build_request( headers.pop("Content-Type") # As we are now sending multipart/form-data instead of application/json - # we need to tell httpx to use it, https://www.python-httpx.org/advanced/#multipart-file-encoding + # we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding if json_data: if not is_dict(json_data): raise TypeError( @@ -479,19 +486,33 @@ def _build_request( ) kwargs["data"] = self._serialize_multipartform(json_data) + # httpx determines whether or not to send a "multipart/form-data" + # request based on the truthiness of the "files" argument. + # This gets around that issue by generating a dict value that + # evaluates to true. + # + # https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186 + if not files: + files = cast(HttpxRequestFiles, ForceMultipartDict()) + + prepared_url = self._prepare_url(options.url) + if "_" in prepared_url.host: + # work around https://github.com/encode/httpx/discussions/2880 + kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, method=options.method, - url=self._prepare_url(options.url), + url=prepared_url, # the `Query` type that we use is incompatible with qs' # `Params` type as it needs to be typed as `Mapping[str, object]` # so that passing a `TypedDict` doesn't cause an error. # https://github.com/microsoft/pyright/issues/3526#event-6715453066 params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None, json=json_data, - files=options.files, + files=files, **kwargs, ) @@ -592,6 +613,12 @@ def default_headers(self) -> dict[str, str | Omit]: **self._custom_headers, } + @property + def default_query(self) -> dict[str, object]: + return { + **self._custom_query, + } + def _validate_headers( self, headers: Headers, # noqa: ARG002 @@ -616,7 +643,10 @@ def base_url(self, url: URL | str) -> None: self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url)) def platform_headers(self) -> Dict[str, str]: - return platform_headers(self._version) + # the actual implementation is in a separate `lru_cache` decorated + # function because adding `lru_cache` to methods will leak memory + # https://github.com/python/cpython/issues/88476 + return platform_headers(self._version, platform=self._platform) def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None: """Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified. @@ -665,7 +695,8 @@ def _calculate_retry_timeout( if retry_after is not None and 0 < retry_after <= 60: return retry_after - nb_retries = max_retries - remaining_retries + # Also cap retry count to 1000 to avoid any potential overflows with `pow` + nb_retries = min(max_retries - remaining_retries, 1000) # Apply exponential backoff, but not more than the max. sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) @@ -858,9 +889,9 @@ def __exit__( def _prepare_options( self, options: FinalRequestOptions, # noqa: ARG002 - ) -> None: + ) -> FinalRequestOptions: """Hook for mutating the given options""" - return None + return options def _prepare_request( self, @@ -882,8 +913,7 @@ def request( *, stream: Literal[True], stream_cls: Type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def request( @@ -893,8 +923,7 @@ def request( remaining_retries: Optional[int] = None, *, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def request( @@ -905,8 +934,7 @@ def request( *, stream: bool = False, stream_cls: Type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def request( self, @@ -917,12 +945,17 @@ def request( stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) def _request( @@ -930,15 +963,20 @@ def _request( *, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: int | None, + retries_taken: int, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: + # create a copy of the options we were given so that if the + # options are mutated later & we then retry, the retries are + # given the original options + input_options = model_copy(options) + cast_to = self._maybe_override_cast_to(cast_to, options) - self._prepare_options(options) + options = self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -956,11 +994,11 @@ def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( - options, + input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -971,11 +1009,11 @@ def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( - options, + input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -998,13 +1036,13 @@ def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): err.response.close() return self._retry_request( - options, + input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1023,25 +1061,26 @@ def _request( response=response, stream=stream, stream_cls=stream_cls, + retries_taken=retries_taken, ) def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: - remaining = remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a @@ -1051,7 +1090,7 @@ def _retry_request( return self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) @@ -1064,6 +1103,7 @@ def _process_response( response: httpx.Response, stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, ) -> ResponseT: origin = get_origin(cast_to) or cast_to @@ -1081,6 +1121,7 @@ def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ), ) @@ -1094,6 +1135,7 @@ def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ) if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): return cast(ResponseT, api_response) @@ -1126,8 +1168,7 @@ def get( cast_to: Type[ResponseT], options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def get( @@ -1138,8 +1179,7 @@ def get( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def get( @@ -1150,8 +1190,7 @@ def get( options: RequestOptions = {}, stream: bool, stream_cls: type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def get( self, @@ -1177,8 +1216,7 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload def post( @@ -1191,8 +1229,7 @@ def post( files: RequestFiles | None = None, stream: Literal[True], stream_cls: type[_StreamT], - ) -> _StreamT: - ... + ) -> _StreamT: ... @overload def post( @@ -1205,8 +1242,7 @@ def post( files: RequestFiles | None = None, stream: bool, stream_cls: type[_StreamT] | None = None, - ) -> ResponseT | _StreamT: - ... + ) -> ResponseT | _StreamT: ... def post( self, @@ -1416,9 +1452,9 @@ async def __aexit__( async def _prepare_options( self, options: FinalRequestOptions, # noqa: ARG002 - ) -> None: + ) -> FinalRequestOptions: """Hook for mutating the given options""" - return None + return options async def _prepare_request( self, @@ -1439,8 +1475,7 @@ async def request( *, stream: Literal[False] = False, remaining_retries: Optional[int] = None, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def request( @@ -1451,8 +1486,7 @@ async def request( stream: Literal[True], stream_cls: type[_AsyncStreamT], remaining_retries: Optional[int] = None, - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def request( @@ -1463,8 +1497,7 @@ async def request( stream: bool, stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def request( self, @@ -1475,12 +1508,17 @@ async def request( stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return await self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) async def _request( @@ -1490,13 +1528,23 @@ async def _request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None, - remaining_retries: int | None, + retries_taken: int, ) -> ResponseT | _AsyncStreamT: + if self._platform is None: + # `get_platform` can make blocking IO calls so we + # execute it earlier while we are in an async context + self._platform = await asyncify(get_platform)() + + # create a copy of the options we were given so that if the + # options are mutated later & we then retry, the retries are + # given the original options + input_options = model_copy(options) + cast_to = self._maybe_override_cast_to(cast_to, options) - await self._prepare_options(options) + options = await self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) await self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -1512,11 +1560,11 @@ async def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return await self._retry_request( - options, + input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1527,11 +1575,11 @@ async def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if remaining_retries > 0: return await self._retry_request( - options, + input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1549,13 +1597,13 @@ async def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): await err.response.aclose() return await self._retry_request( - options, + input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1574,25 +1622,26 @@ async def _request( response=response, stream=stream, stream_cls=stream_cls, + retries_taken=retries_taken, ) async def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_AsyncStreamT] | None, ) -> ResponseT | _AsyncStreamT: - remaining = remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) @@ -1600,7 +1649,7 @@ async def _retry_request( return await self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) @@ -1613,6 +1662,7 @@ async def _process_response( response: httpx.Response, stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, ) -> ResponseT: origin = get_origin(cast_to) or cast_to @@ -1630,6 +1680,7 @@ async def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ), ) @@ -1643,6 +1694,7 @@ async def _process_response( stream=stream, stream_cls=stream_cls, options=options, + retries_taken=retries_taken, ) if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): return cast(ResponseT, api_response) @@ -1665,8 +1717,7 @@ async def get( cast_to: Type[ResponseT], options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def get( @@ -1677,8 +1728,7 @@ async def get( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_AsyncStreamT], - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def get( @@ -1689,8 +1739,7 @@ async def get( options: RequestOptions = {}, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def get( self, @@ -1714,8 +1763,7 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[False] = False, - ) -> ResponseT: - ... + ) -> ResponseT: ... @overload async def post( @@ -1728,8 +1776,7 @@ async def post( options: RequestOptions = {}, stream: Literal[True], stream_cls: type[_AsyncStreamT], - ) -> _AsyncStreamT: - ... + ) -> _AsyncStreamT: ... @overload async def post( @@ -1742,8 +1789,7 @@ async def post( options: RequestOptions = {}, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - ) -> ResponseT | _AsyncStreamT: - ... + ) -> ResponseT | _AsyncStreamT: ... async def post( self, @@ -1848,6 +1894,11 @@ def make_request_options( return options +class ForceMultipartDict(Dict[str, None]): + def __bool__(self) -> bool: + return True + + class OtherPlatform: def __init__(self, name: str) -> None: self.name = name @@ -1915,11 +1966,11 @@ def get_platform() -> Platform: @lru_cache(maxsize=None) -def platform_headers(version: str) -> Dict[str, str]: +def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]: return { "X-Stainless-Lang": "python", "X-Stainless-Package-Version": version, - "X-Stainless-OS": str(get_platform()), + "X-Stainless-OS": str(platform or get_platform()), "X-Stainless-Arch": str(get_architecture()), "X-Stainless-Runtime": get_python_runtime(), "X-Stainless-Runtime-Version": get_python_version(), @@ -1954,7 +2005,6 @@ def get_python_version() -> str: def get_architecture() -> Arch: try: - python_bitness, _ = platform.architecture() machine = platform.machine().lower() except Exception: return "unknown" @@ -1970,7 +2020,7 @@ def get_architecture() -> Arch: return "x64" # TODO: untested - if python_bitness == "32bit": + if sys.maxsize <= 2**32: return "x32" if machine: diff --git a/src/maisa/_compat.py b/src/maisa/_compat.py index 74c7639b..4794129c 100644 --- a/src/maisa/_compat.py +++ b/src/maisa/_compat.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload from datetime import date, datetime -from typing_extensions import Self +from typing_extensions import Self, Literal import pydantic from pydantic.fields import FieldInfo -from ._types import StrBytesIntFloat +from ._types import IncEx, StrBytesIntFloat _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) @@ -118,10 +118,10 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: return model.__fields__ # type: ignore -def model_copy(model: _ModelT) -> _ModelT: +def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: if PYDANTIC_V2: - return model.model_copy() - return model.copy() # type: ignore + return model.model_copy(deep=deep) + return model.copy(deep=deep) # type: ignore def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: @@ -133,17 +133,24 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: def model_dump( model: pydantic.BaseModel, *, + exclude: IncEx | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, + warnings: bool = True, + mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2: + if PYDANTIC_V2 or hasattr(model, "model_dump"): return model.model_dump( + mode=mode, + exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, + warnings=warnings, ) return cast( "dict[str, Any]", model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, ), @@ -159,22 +166,19 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT: # generic models if TYPE_CHECKING: - class GenericModel(pydantic.BaseModel): - ... + class GenericModel(pydantic.BaseModel): ... else: if PYDANTIC_V2: # there no longer needs to be a distinction in v2 but # we still have to create our own subclass to avoid # inconsistent MRO ordering errors - class GenericModel(pydantic.BaseModel): - ... + class GenericModel(pydantic.BaseModel): ... else: import pydantic.generics - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): - ... + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... # cached properties @@ -193,26 +197,21 @@ class typed_cached_property(Generic[_T]): func: Callable[[Any], _T] attrname: str | None - def __init__(self, func: Callable[[Any], _T]) -> None: - ... + def __init__(self, func: Callable[[Any], _T]) -> None: ... @overload - def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: - ... + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... @overload - def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: - ... + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: raise NotImplementedError() - def __set_name__(self, owner: type[Any], name: str) -> None: - ... + def __set_name__(self, owner: type[Any], name: str) -> None: ... # __set__ is not defined at runtime, but @cached_property is designed to be settable - def __set__(self, instance: object, value: _T) -> None: - ... + def __set__(self, instance: object, value: _T) -> None: ... else: try: from functools import cached_property as cached_property diff --git a/src/maisa/_files.py b/src/maisa/_files.py index 0d2022ae..715cc207 100644 --- a/src/maisa/_files.py +++ b/src/maisa/_files.py @@ -39,13 +39,11 @@ def assert_is_file_content(obj: object, *, key: str | None = None) -> None: @overload -def to_httpx_files(files: None) -> None: - ... +def to_httpx_files(files: None) -> None: ... @overload -def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: - ... +def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: @@ -83,13 +81,11 @@ def _read_file_content(file: FileContent) -> HttpxFileContent: @overload -async def async_to_httpx_files(files: None) -> None: - ... +async def async_to_httpx_files(files: None) -> None: ... @overload -async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: - ... +async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: diff --git a/src/maisa/_models.py b/src/maisa/_models.py index ff3f54e2..6cb469e2 100644 --- a/src/maisa/_models.py +++ b/src/maisa/_models.py @@ -10,6 +10,7 @@ ClassVar, Protocol, Required, + ParamSpec, TypedDict, TypeGuard, final, @@ -36,6 +37,7 @@ PropertyInfo, is_list, is_given, + json_safe, lru_cache, is_mapping, parse_date, @@ -62,11 +64,14 @@ from ._constants import RAW_RESPONSE_HEADER if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField, ModelFieldsSchema + from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema __all__ = ["BaseModel", "GenericModel"] _T = TypeVar("_T") +_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel") + +P = ParamSpec("P") @runtime_checkable @@ -172,7 +177,7 @@ def __str__(self) -> str: # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. @classmethod @override - def construct( + def construct( # pyright: ignore[reportIncompatibleMethodOverride] cls: Type[ModelT], _fields_set: set[str] | None = None, **values: object, @@ -244,14 +249,16 @@ def model_dump( self, *, mode: Literal["json", "python"] | str = "python", - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool = True, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -273,13 +280,17 @@ def model_dump( Returns: A dictionary representation of the model. """ - if mode != "python": - raise ValueError("mode is only supported in Pydantic v2") + if mode not in {"json", "python"}: + raise ValueError("mode must be either 'json' or 'python'") if round_trip != False: raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") - return super().dict( # pyright: ignore[reportDeprecated] + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, by_alias=by_alias, @@ -288,19 +299,23 @@ def model_dump( exclude_none=exclude_none, ) + return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + @override def model_dump_json( self, *, indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool = True, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -324,6 +339,10 @@ def model_dump_json( raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, @@ -364,9 +383,43 @@ def is_basemodel(type_: type) -> bool: def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: origin = get_origin(type_) or type_ + if not inspect.isclass(origin): + return False return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) +def build( + base_model_cls: Callable[P, _BaseModelT], + *args: P.args, + **kwargs: P.kwargs, +) -> _BaseModelT: + """Construct a BaseModel class without validation. + + This is useful for cases where you need to instantiate a `BaseModel` + from an API response as this provides type-safe params which isn't supported + by helpers like `construct_type()`. + + ```py + build(MyModel, my_field_a="foo", my_field_b=123) + ``` + """ + if args: + raise TypeError( + "Received positional arguments which are not supported; Keyword arguments must be used instead", + ) + + return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) + + +def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: + """Loose coercion to the expected type with construction of nested values. + + Note: the returned value from this function is not guaranteed to match the + given type. + """ + return cast(_T, construct_type(value=value, type_=type_)) + + def construct_type(*, value: object, type_: object) -> object: """Loose coercion to the expected type with construction of nested values. @@ -550,7 +603,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, field_schema = field["schema"] if field_schema["type"] == "literal": - for entry in field_schema["expected"]: + for entry in cast("LiteralSchema", field_schema)["expected"]: if isinstance(entry, str): mapping[entry] = variant else: @@ -604,6 +657,14 @@ def validate_type(*, type_: type[_T], value: object) -> _T: return cast(_T, _validate_non_model_type(type_=type_, value=value)) +def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None: + """Add a pydantic config for the given type. + + Note: this is a no-op on Pydantic v1. + """ + setattr(typ, "__pydantic_config__", config) # noqa: B010 + + # our use of subclasssing here causes weirdness for type checkers, # so we just pretend that we don't subclass if TYPE_CHECKING: diff --git a/src/maisa/_response.py b/src/maisa/_response.py index 0a3b276a..48f1a4ab 100644 --- a/src/maisa/_response.py +++ b/src/maisa/_response.py @@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]): http_response: httpx.Response + retries_taken: int + """The number of retries made. If no retries happened this will be `0`""" + def __init__( self, *, @@ -64,6 +67,7 @@ def __init__( stream: bool, stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, options: FinalRequestOptions, + retries_taken: int = 0, ) -> None: self._cast_to = cast_to self._client = client @@ -72,6 +76,7 @@ def __init__( self._stream_cls = stream_cls self._options = options self.http_response = raw + self.retries_taken = retries_taken @property def headers(self) -> httpx.Headers: @@ -187,6 +192,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: if cast_to == float: return cast(R, float(response.text)) + if cast_to == bool: + return cast(R, response.text.lower() == "true") + origin = get_origin(cast_to) or cast_to if origin == APIResponse: @@ -255,12 +263,10 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: class APIResponse(BaseAPIResponse[R]): @overload - def parse(self, *, to: type[_T]) -> _T: - ... + def parse(self, *, to: type[_T]) -> _T: ... @overload - def parse(self) -> R: - ... + def parse(self) -> R: ... def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. @@ -359,12 +365,10 @@ def iter_lines(self) -> Iterator[str]: class AsyncAPIResponse(BaseAPIResponse[R]): @overload - async def parse(self, *, to: type[_T]) -> _T: - ... + async def parse(self, *, to: type[_T]) -> _T: ... @overload - async def parse(self) -> R: - ... + async def parse(self) -> R: ... async def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. diff --git a/src/maisa/_types.py b/src/maisa/_types.py index 94a83c8a..a865073d 100644 --- a/src/maisa/_types.py +++ b/src/maisa/_types.py @@ -16,7 +16,7 @@ Optional, Sequence, ) -from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable import httpx import pydantic @@ -111,8 +111,7 @@ class NotGiven: For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: - ... + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... get(timeout=1) # 1s timeout @@ -162,16 +161,14 @@ def build( *, response: Response, data: object, - ) -> _T: - ... + ) -> _T: ... Headers = Mapping[str, Union[str, Omit]] class HeadersLikeProtocol(Protocol): - def get(self, __key: str) -> str | None: - ... + def get(self, __key: str) -> str | None: ... HeadersLike = Union[Headers, HeadersLikeProtocol] @@ -196,7 +193,9 @@ def get(self, __key: str) -> str | None: # Note: copied from Pydantic # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 -IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +IncEx: TypeAlias = Union[ + Set[int], Set[str], Mapping[int, Union["IncEx", Literal[True]]], Mapping[str, Union["IncEx", Literal[True]]] +] PostParser = Callable[[Any], Any] diff --git a/src/maisa/_utils/__init__.py b/src/maisa/_utils/__init__.py index 31b5b227..a7cff3c0 100644 --- a/src/maisa/_utils/__init__.py +++ b/src/maisa/_utils/__init__.py @@ -6,6 +6,7 @@ is_list as is_list, is_given as is_given, is_tuple as is_tuple, + json_safe as json_safe, lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, @@ -49,3 +50,7 @@ maybe_transform as maybe_transform, async_maybe_transform as async_maybe_transform, ) +from ._reflection import ( + function_has_argument as function_has_argument, + assert_signatures_in_sync as assert_signatures_in_sync, +) diff --git a/src/maisa/_utils/_proxy.py b/src/maisa/_utils/_proxy.py index c46a62a6..ffd883e9 100644 --- a/src/maisa/_utils/_proxy.py +++ b/src/maisa/_utils/_proxy.py @@ -59,5 +59,4 @@ def __as_proxied__(self) -> T: return cast(T, self) @abstractmethod - def __load__(self) -> T: - ... + def __load__(self) -> T: ... diff --git a/src/maisa/_utils/_reflection.py b/src/maisa/_utils/_reflection.py new file mode 100644 index 00000000..89aa712a --- /dev/null +++ b/src/maisa/_utils/_reflection.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import inspect +from typing import Any, Callable + + +def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool: + """Returns whether or not the given function has a specific parameter""" + sig = inspect.signature(func) + return arg_name in sig.parameters + + +def assert_signatures_in_sync( + source_func: Callable[..., Any], + check_func: Callable[..., Any], + *, + exclude_params: set[str] = set(), +) -> None: + """Ensure that the signature of the second function matches the first.""" + + check_sig = inspect.signature(check_func) + source_sig = inspect.signature(source_func) + + errors: list[str] = [] + + for name, source_param in source_sig.parameters.items(): + if name in exclude_params: + continue + + custom_param = check_sig.parameters.get(name) + if not custom_param: + errors.append(f"the `{name}` param is missing") + continue + + if custom_param.annotation != source_param.annotation: + errors.append( + f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}" + ) + continue + + if errors: + raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors)) diff --git a/src/maisa/_utils/_sync.py b/src/maisa/_utils/_sync.py index 595924e5..d0d81033 100644 --- a/src/maisa/_utils/_sync.py +++ b/src/maisa/_utils/_sync.py @@ -7,6 +7,8 @@ import anyio import anyio.to_thread +from ._reflection import function_has_argument + T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") @@ -59,6 +61,21 @@ def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str: async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: partial_f = functools.partial(function, *args, **kwargs) - return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter) + + # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old + # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid + # surfacing deprecation warnings. + if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"): + return await anyio.to_thread.run_sync( + partial_f, + abandon_on_cancel=cancellable, + limiter=limiter, + ) + + return await anyio.to_thread.run_sync( + partial_f, + cancellable=cancellable, + limiter=limiter, + ) return wrapper diff --git a/src/maisa/_utils/_transform.py b/src/maisa/_utils/_transform.py index 47e262a5..d7c05345 100644 --- a/src/maisa/_utils/_transform.py +++ b/src/maisa/_utils/_transform.py @@ -173,6 +173,11 @@ def _transform_recursive( # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + inner_type = extract_type_arg(stripped_type, 0) return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] @@ -186,7 +191,7 @@ def _transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: @@ -324,7 +329,7 @@ async def _async_transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: diff --git a/src/maisa/_utils/_utils.py b/src/maisa/_utils/_utils.py index 17904ce6..e5811bba 100644 --- a/src/maisa/_utils/_utils.py +++ b/src/maisa/_utils/_utils.py @@ -16,11 +16,12 @@ overload, ) from pathlib import Path +from datetime import date, datetime from typing_extensions import TypeGuard import sniffio -from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike from .._compat import parse_date as parse_date, parse_datetime as parse_datetime _T = TypeVar("_T") @@ -211,20 +212,17 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: Example usage: ```py @overload - def foo(*, a: str) -> str: - ... + def foo(*, a: str) -> str: ... @overload - def foo(*, b: bool) -> str: - ... + def foo(*, b: bool) -> str: ... # This enforces the same constraints that a static type checker would # i.e. that either a or b must be passed to the function @required_args(["a"], ["b"]) - def foo(*, a: str | None = None, b: bool | None = None) -> str: - ... + def foo(*, a: str | None = None, b: bool | None = None) -> str: ... ``` """ @@ -286,18 +284,15 @@ def wrapper(*args: object, **kwargs: object) -> object: @overload -def strip_not_given(obj: None) -> None: - ... +def strip_not_given(obj: None) -> None: ... @overload -def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: - ... +def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... @overload -def strip_not_given(obj: object) -> object: - ... +def strip_not_given(obj: object) -> object: ... def strip_not_given(obj: object | None) -> object: @@ -369,13 +364,13 @@ def file_from_path(path: str) -> FileTypes: def get_required_header(headers: HeadersLike, header: str) -> str: lower_header = header.lower() - if isinstance(headers, Mapping): - headers = cast(Headers, headers) - for k, v in headers.items(): + if is_mapping_t(headers): + # mypy doesn't understand the type narrowing here + for k, v in headers.items(): # type: ignore if k.lower() == lower_header and isinstance(v, str): return v - """ to deal with the case where the header looks like Stainless-Event-Id """ + # to deal with the case where the header looks like Stainless-Event-Id intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) for normalized_header in [header, lower_header, header.upper(), intercaps_header]: @@ -401,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: maxsize=maxsize, ) return cast(Any, wrapper) # type: ignore[no-any-return] + + +def json_safe(data: object) -> object: + """Translates a mapping / sequence recursively in the same fashion + as `pydantic` v2's `model_dump(mode="json")`. + """ + if is_mapping(data): + return {json_safe(key): json_safe(value) for key, value in data.items()} + + if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)): + return [json_safe(item) for item in data] + + if isinstance(data, (datetime, date)): + return data.isoformat() + + return data diff --git a/src/maisa/resources/capabilities/capabilities.py b/src/maisa/resources/capabilities/capabilities.py index e8989be5..7c760b4e 100644 --- a/src/maisa/resources/capabilities/capabilities.py +++ b/src/maisa/resources/capabilities/capabilities.py @@ -29,9 +29,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.shared.text_summary import TextSummary from ...types.shared.text_extractor import TextExtractor from ...types.shared.text_comparator import TextComparator @@ -46,10 +44,21 @@ def media(self) -> MediaResource: @cached_property def with_raw_response(self) -> CapabilitiesResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return CapabilitiesResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> CapabilitiesResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return CapabilitiesResourceWithStreamingResponse(self) def compare( @@ -224,10 +233,21 @@ def media(self) -> AsyncMediaResource: @cached_property def with_raw_response(self) -> AsyncCapabilitiesResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncCapabilitiesResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncCapabilitiesResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncCapabilitiesResourceWithStreamingResponse(self) async def compare( diff --git a/src/maisa/resources/capabilities/media.py b/src/maisa/resources/capabilities/media.py index 04e41a0e..f9287395 100644 --- a/src/maisa/resources/capabilities/media.py +++ b/src/maisa/resources/capabilities/media.py @@ -22,9 +22,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.capabilities import media_compare_params, media_extract_params, media_summarize_params from ...types.shared.text_summary import TextSummary from ...types.shared.text_extractor import TextExtractor @@ -36,10 +34,21 @@ class MediaResource(SyncAPIResource): @cached_property def with_raw_response(self) -> MediaResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return MediaResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> MediaResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return MediaResourceWithStreamingResponse(self) def compare( @@ -133,11 +142,10 @@ def compare( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file1"], ["file2"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/capabilities/compare/media", body=maybe_transform(body, media_compare_params.MediaCompareParams), @@ -233,11 +241,10 @@ def extract( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/capabilities/extract/media", body=maybe_transform(body, media_extract_params.MediaExtractParams), @@ -297,11 +304,10 @@ def summarize( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/capabilities/summarize/media", body=maybe_transform(body, media_summarize_params.MediaSummarizeParams), @@ -316,10 +322,21 @@ def summarize( class AsyncMediaResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncMediaResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncMediaResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncMediaResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncMediaResourceWithStreamingResponse(self) async def compare( @@ -413,11 +430,10 @@ async def compare( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file1"], ["file2"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/compare/media", body=await async_maybe_transform(body, media_compare_params.MediaCompareParams), @@ -513,11 +529,10 @@ async def extract( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/extract/media", body=await async_maybe_transform(body, media_extract_params.MediaExtractParams), @@ -577,11 +592,10 @@ async def summarize( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/summarize/media", body=await async_maybe_transform(body, media_summarize_params.MediaSummarizeParams), diff --git a/src/maisa/resources/file_interpreter/file_interpreter.py b/src/maisa/resources/file_interpreter/file_interpreter.py index f94c2a6e..6519a992 100644 --- a/src/maisa/resources/file_interpreter/file_interpreter.py +++ b/src/maisa/resources/file_interpreter/file_interpreter.py @@ -71,10 +71,21 @@ def from_audio(self) -> FromAudioResource: @cached_property def with_raw_response(self) -> FileInterpreterResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FileInterpreterResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FileInterpreterResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FileInterpreterResourceWithStreamingResponse(self) @@ -101,10 +112,21 @@ def from_audio(self) -> AsyncFromAudioResource: @cached_property def with_raw_response(self) -> AsyncFileInterpreterResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFileInterpreterResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFileInterpreterResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFileInterpreterResourceWithStreamingResponse(self) diff --git a/src/maisa/resources/file_interpreter/from_audio.py b/src/maisa/resources/file_interpreter/from_audio.py index 96399bcd..73c52414 100644 --- a/src/maisa/resources/file_interpreter/from_audio.py +++ b/src/maisa/resources/file_interpreter/from_audio.py @@ -21,9 +21,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.file_interpreter import from_audio_create_params __all__ = ["FromAudioResource", "AsyncFromAudioResource"] @@ -32,10 +30,21 @@ class FromAudioResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FromAudioResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FromAudioResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FromAudioResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FromAudioResourceWithStreamingResponse(self) def create( @@ -63,11 +72,10 @@ def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/file-interpreter/from-audio", body=maybe_transform(body, from_audio_create_params.FromAudioCreateParams), @@ -82,10 +90,21 @@ def create( class AsyncFromAudioResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncFromAudioResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFromAudioResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFromAudioResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFromAudioResourceWithStreamingResponse(self) async def create( @@ -113,11 +132,10 @@ async def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-audio", body=await async_maybe_transform(body, from_audio_create_params.FromAudioCreateParams), diff --git a/src/maisa/resources/file_interpreter/from_docx.py b/src/maisa/resources/file_interpreter/from_docx.py index eec6adaf..ca0b6ee0 100644 --- a/src/maisa/resources/file_interpreter/from_docx.py +++ b/src/maisa/resources/file_interpreter/from_docx.py @@ -21,9 +21,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.file_interpreter import from_docx_create_params __all__ = ["FromDocxResource", "AsyncFromDocxResource"] @@ -32,10 +30,21 @@ class FromDocxResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FromDocxResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FromDocxResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FromDocxResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FromDocxResourceWithStreamingResponse(self) def create( @@ -63,11 +72,10 @@ def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/file-interpreter/from-docx", body=maybe_transform(body, from_docx_create_params.FromDocxCreateParams), @@ -82,10 +90,21 @@ def create( class AsyncFromDocxResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncFromDocxResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFromDocxResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFromDocxResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFromDocxResourceWithStreamingResponse(self) async def create( @@ -113,11 +132,10 @@ async def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-docx", body=await async_maybe_transform(body, from_docx_create_params.FromDocxCreateParams), diff --git a/src/maisa/resources/file_interpreter/from_html.py b/src/maisa/resources/file_interpreter/from_html.py index 166bb9a6..ba998348 100644 --- a/src/maisa/resources/file_interpreter/from_html.py +++ b/src/maisa/resources/file_interpreter/from_html.py @@ -21,9 +21,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.file_interpreter import from_html_create_params __all__ = ["FromHTMLResource", "AsyncFromHTMLResource"] @@ -32,10 +30,21 @@ class FromHTMLResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FromHTMLResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FromHTMLResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FromHTMLResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FromHTMLResourceWithStreamingResponse(self) def create( @@ -63,11 +72,10 @@ def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/file-interpreter/from-html", body=maybe_transform(body, from_html_create_params.FromHTMLCreateParams), @@ -82,10 +90,21 @@ def create( class AsyncFromHTMLResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncFromHTMLResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFromHTMLResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFromHTMLResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFromHTMLResourceWithStreamingResponse(self) async def create( @@ -113,11 +132,10 @@ async def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-html", body=await async_maybe_transform(body, from_html_create_params.FromHTMLCreateParams), diff --git a/src/maisa/resources/file_interpreter/from_image.py b/src/maisa/resources/file_interpreter/from_image.py index a35ab876..bc967a10 100644 --- a/src/maisa/resources/file_interpreter/from_image.py +++ b/src/maisa/resources/file_interpreter/from_image.py @@ -21,9 +21,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.file_interpreter import from_image_create_params __all__ = ["FromImageResource", "AsyncFromImageResource"] @@ -32,10 +30,21 @@ class FromImageResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FromImageResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FromImageResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FromImageResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FromImageResourceWithStreamingResponse(self) def create( @@ -63,11 +72,10 @@ def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/file-interpreter/from-image", body=maybe_transform(body, from_image_create_params.FromImageCreateParams), @@ -82,10 +90,21 @@ def create( class AsyncFromImageResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncFromImageResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFromImageResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFromImageResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFromImageResourceWithStreamingResponse(self) async def create( @@ -113,11 +132,10 @@ async def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-image", body=await async_maybe_transform(body, from_image_create_params.FromImageCreateParams), diff --git a/src/maisa/resources/file_interpreter/from_pdf.py b/src/maisa/resources/file_interpreter/from_pdf.py index 9889625a..d8f2a691 100644 --- a/src/maisa/resources/file_interpreter/from_pdf.py +++ b/src/maisa/resources/file_interpreter/from_pdf.py @@ -21,9 +21,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.file_interpreter import from_pdf_create_params __all__ = ["FromPdfResource", "AsyncFromPdfResource"] @@ -32,10 +30,21 @@ class FromPdfResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FromPdfResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return FromPdfResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> FromPdfResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return FromPdfResourceWithStreamingResponse(self) def create( @@ -64,11 +73,10 @@ def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/file-interpreter/from-pdf", body=maybe_transform(body, from_pdf_create_params.FromPdfCreateParams), @@ -87,10 +95,21 @@ def create( class AsyncFromPdfResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncFromPdfResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncFromPdfResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncFromPdfResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncFromPdfResourceWithStreamingResponse(self) async def create( @@ -119,11 +138,10 @@ async def create( """ body = deepcopy_minimal({"file": file}) files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-pdf", body=await async_maybe_transform(body, from_pdf_create_params.FromPdfCreateParams), diff --git a/src/maisa/resources/kpu.py b/src/maisa/resources/kpu.py index de7e243f..c6b22de2 100644 --- a/src/maisa/resources/kpu.py +++ b/src/maisa/resources/kpu.py @@ -23,9 +23,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from .._base_client import ( - make_request_options, -) +from .._base_client import make_request_options __all__ = ["KpuResource", "AsyncKpuResource"] @@ -33,10 +31,21 @@ class KpuResource(SyncAPIResource): @cached_property def with_raw_response(self) -> KpuResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return KpuResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> KpuResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return KpuResourceWithStreamingResponse(self) def run( @@ -101,11 +110,10 @@ def run( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/v1/kpu/run", body=maybe_transform(body, kpu_run_params.KpuRunParams), @@ -130,10 +138,21 @@ def run( class AsyncKpuResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncKpuResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncKpuResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncKpuResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncKpuResourceWithStreamingResponse(self) async def run( @@ -198,11 +217,10 @@ async def run( } ) files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]]) - if files: - # It should be noted that the actual Content-Type header that will be - # sent to the server will contain a `boundary` parameter, e.g. - # multipart/form-data; boundary=---abc-- - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/kpu/run", body=await async_maybe_transform(body, kpu_run_params.KpuRunParams), diff --git a/src/maisa/resources/models/embeddings.py b/src/maisa/resources/models/embeddings.py index 427afddd..15f0aa17 100644 --- a/src/maisa/resources/models/embeddings.py +++ b/src/maisa/resources/models/embeddings.py @@ -19,9 +19,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._base_client import ( - make_request_options, -) +from ..._base_client import make_request_options from ...types.models import embedding_create_params from ...types.models.embeddings import Embeddings @@ -31,10 +29,21 @@ class EmbeddingsResource(SyncAPIResource): @cached_property def with_raw_response(self) -> EmbeddingsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return EmbeddingsResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> EmbeddingsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return EmbeddingsResourceWithStreamingResponse(self) def create( @@ -75,10 +84,21 @@ def create( class AsyncEmbeddingsResource(AsyncAPIResource): @cached_property def with_raw_response(self) -> AsyncEmbeddingsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncEmbeddingsResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncEmbeddingsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncEmbeddingsResourceWithStreamingResponse(self) async def create( diff --git a/src/maisa/resources/models/models.py b/src/maisa/resources/models/models.py index c03c3690..04e974a9 100644 --- a/src/maisa/resources/models/models.py +++ b/src/maisa/resources/models/models.py @@ -23,10 +23,21 @@ def embeddings(self) -> EmbeddingsResource: @cached_property def with_raw_response(self) -> ModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return ModelsResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> ModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return ModelsResourceWithStreamingResponse(self) @@ -37,10 +48,21 @@ def embeddings(self) -> AsyncEmbeddingsResource: @cached_property def with_raw_response(self) -> AsyncModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/maisaai/python-sdk#accessing-raw-response-data-eg-headers + """ return AsyncModelsResourceWithRawResponse(self) @cached_property def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/maisaai/python-sdk#with_streaming_response + """ return AsyncModelsResourceWithStreamingResponse(self) diff --git a/src/maisa/types/shared/text_comparator.py b/src/maisa/types/shared/text_comparator.py index 1a052b1d..646f30ad 100644 --- a/src/maisa/types/shared/text_comparator.py +++ b/src/maisa/types/shared/text_comparator.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["TextComparator"] diff --git a/src/maisa/types/shared/text_extractor.py b/src/maisa/types/shared/text_extractor.py index 70d694a1..d58d0905 100644 --- a/src/maisa/types/shared/text_extractor.py +++ b/src/maisa/types/shared/text_extractor.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["TextExtractor"] diff --git a/src/maisa/types/shared/text_summary.py b/src/maisa/types/shared/text_summary.py index 91e2e375..86d1c11a 100644 --- a/src/maisa/types/shared/text_summary.py +++ b/src/maisa/types/shared/text_summary.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["TextSummary"] diff --git a/tests/api_resources/test_capabilities.py b/tests/api_resources/test_capabilities.py index c5df9a99..724f1f6d 100644 --- a/tests/api_resources/test_capabilities.py +++ b/tests/api_resources/test_capabilities.py @@ -24,8 +24,8 @@ def test_method_compare(self, client: Maisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -38,8 +38,8 @@ def test_method_compare_with_all_params(self, client: Maisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, lang="en", @@ -54,8 +54,8 @@ def test_raw_response_compare(self, client: Maisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -72,8 +72,8 @@ def test_streaming_response_compare(self, client: Maisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) as response: @@ -91,8 +91,8 @@ def test_method_extract(self, client: Maisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -104,8 +104,8 @@ def test_method_extract_with_all_params(self, client: Maisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, lang="en", @@ -118,8 +118,8 @@ def test_raw_response_extract(self, client: Maisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -135,8 +135,8 @@ def test_streaming_response_extract(self, client: Maisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) as response: @@ -161,7 +161,7 @@ def test_method_summarize_with_all_params(self, client: Maisa) -> None: text="Example long text...", format="paragraph", lang="en", - length="medium", + length="short", summary_hint="Example summary of the text...", ) assert_matches_type(TextSummary, capability, path=["response"]) @@ -201,8 +201,8 @@ async def test_method_compare(self, async_client: AsyncMaisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -215,8 +215,8 @@ async def test_method_compare_with_all_params(self, async_client: AsyncMaisa) -> text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, lang="en", @@ -231,8 +231,8 @@ async def test_raw_response_compare(self, async_client: AsyncMaisa) -> None: text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -249,8 +249,8 @@ async def test_streaming_response_compare(self, async_client: AsyncMaisa) -> Non text2="Sed ut perspiciatis unde omnis", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) as response: @@ -268,8 +268,8 @@ async def test_method_extract(self, async_client: AsyncMaisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -281,8 +281,8 @@ async def test_method_extract_with_all_params(self, async_client: AsyncMaisa) -> text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, lang="en", @@ -295,8 +295,8 @@ async def test_raw_response_extract(self, async_client: AsyncMaisa) -> None: text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) @@ -312,8 +312,8 @@ async def test_streaming_response_extract(self, async_client: AsyncMaisa) -> Non text="Example long text...", variables={ "name": { - "type": "string", "description": "The name of the person.", + "type": "string", } }, ) as response: @@ -338,7 +338,7 @@ async def test_method_summarize_with_all_params(self, async_client: AsyncMaisa) text="Example long text...", format="paragraph", lang="en", - length="medium", + length="short", summary_hint="Example summary of the text...", ) assert_matches_type(TextSummary, capability, path=["response"]) diff --git a/tests/api_resources/test_kpu.py b/tests/api_resources/test_kpu.py index a9f9f1be..f916db66 100644 --- a/tests/api_resources/test_kpu.py +++ b/tests/api_resources/test_kpu.py @@ -19,26 +19,26 @@ class TestKpu: @parametrize def test_method_run(self, client: Maisa) -> None: kpu = client.kpu.run( - query="string", + query="query", ) assert_matches_type(object, kpu, path=["response"]) @parametrize def test_method_run_with_all_params(self, client: Maisa) -> None: kpu = client.kpu.run( - query="string", + query="query", explain_steps=True, retries=1, file=[b"raw file contents", b"raw file contents", b"raw file contents"], reasoner_model="gpt-4-turbo", - reasoner_prompt="string", + reasoner_prompt="reasoner_prompt", ) assert_matches_type(object, kpu, path=["response"]) @parametrize def test_raw_response_run(self, client: Maisa) -> None: response = client.kpu.with_raw_response.run( - query="string", + query="query", ) assert response.is_closed is True @@ -49,7 +49,7 @@ def test_raw_response_run(self, client: Maisa) -> None: @parametrize def test_streaming_response_run(self, client: Maisa) -> None: with client.kpu.with_streaming_response.run( - query="string", + query="query", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -66,26 +66,26 @@ class TestAsyncKpu: @parametrize async def test_method_run(self, async_client: AsyncMaisa) -> None: kpu = await async_client.kpu.run( - query="string", + query="query", ) assert_matches_type(object, kpu, path=["response"]) @parametrize async def test_method_run_with_all_params(self, async_client: AsyncMaisa) -> None: kpu = await async_client.kpu.run( - query="string", + query="query", explain_steps=True, retries=1, file=[b"raw file contents", b"raw file contents", b"raw file contents"], reasoner_model="gpt-4-turbo", - reasoner_prompt="string", + reasoner_prompt="reasoner_prompt", ) assert_matches_type(object, kpu, path=["response"]) @parametrize async def test_raw_response_run(self, async_client: AsyncMaisa) -> None: response = await async_client.kpu.with_raw_response.run( - query="string", + query="query", ) assert response.is_closed is True @@ -96,7 +96,7 @@ async def test_raw_response_run(self, async_client: AsyncMaisa) -> None: @parametrize async def test_streaming_response_run(self, async_client: AsyncMaisa) -> None: async with async_client.kpu.with_streaming_response.run( - query="string", + query="query", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/conftest.py b/tests/conftest.py index a4bcb401..e4a04e40 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ from __future__ import annotations import os -import asyncio import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator import pytest +from pytest_asyncio import is_async_test from maisa import Maisa, AsyncMaisa @@ -17,11 +17,13 @@ logging.getLogger("maisa").setLevel(logging.DEBUG) -@pytest.fixture(scope="session") -def event_loop() -> Iterator[asyncio.AbstractEventLoop]: - loop = asyncio.new_event_loop() - yield loop - loop.close() +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") diff --git a/tests/test_client.py b/tests/test_client.py index aa708a4c..f75c5e2f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,6 +10,7 @@ import tracemalloc from typing import Any, Union, cast from unittest import mock +from typing_extensions import Literal import httpx import pytest @@ -17,6 +18,7 @@ from pydantic import ValidationError from maisa import Maisa, AsyncMaisa, APIResponseValidationError +from maisa._types import Omit from maisa._models import BaseModel, FinalRequestOptions from maisa._constants import RAW_RESPONSE_HEADER from maisa._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError @@ -674,6 +676,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 7.8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -715,6 +718,85 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non assert _get_open_connections(self.client) == 0 + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + def test_retries_taken( + self, + client: Maisa, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = client.capabilities.with_raw_response.summarize(text="Example long text...") + + assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_omit_retry_count_header(self, client: Maisa, failures_before_success: int, respx_mock: MockRouter) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = client.capabilities.with_raw_response.summarize( + text="Example long text...", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_overwrite_retry_count_header( + self, client: Maisa, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = client.capabilities.with_raw_response.summarize( + text="Example long text...", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" + class TestAsyncMaisa: client = AsyncMaisa(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1359,6 +1441,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 7.8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -1400,3 +1483,87 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) ) assert _get_open_connections(self.client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + async def test_retries_taken( + self, + async_client: AsyncMaisa, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = await client.capabilities.with_raw_response.summarize(text="Example long text...") + + assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_omit_retry_count_header( + self, async_client: AsyncMaisa, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = await client.capabilities.with_raw_response.summarize( + text="Example long text...", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_overwrite_retry_count_header( + self, async_client: AsyncMaisa, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=retry_handler) + + response = await client.capabilities.with_raw_response.summarize( + text="Example long text...", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py index a33fd8b8..3697c8c9 100644 --- a/tests/test_deepcopy.py +++ b/tests/test_deepcopy.py @@ -41,8 +41,7 @@ def test_nested_list() -> None: assert_different_identities(obj1[1], obj2[1]) -class MyObject: - ... +class MyObject: ... def test_ignores_other_types() -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 9e18d54e..0dbcbc18 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,7 +31,7 @@ class NestedModel(BaseModel): # mismatched types m = NestedModel.construct(nested="hello!") - assert m.nested == "hello!" + assert cast(Any, m.nested) == "hello!" def test_optional_nested_model() -> None: @@ -48,7 +48,7 @@ class NestedModel(BaseModel): # mismatched types m3 = NestedModel.construct(nested={"foo"}) assert isinstance(cast(Any, m3.nested), set) - assert m3.nested == {"foo"} + assert cast(Any, m3.nested) == {"foo"} def test_list_nested_model() -> None: @@ -245,7 +245,7 @@ class Model(BaseModel): assert m.foo is True m = Model.construct(foo="CARD_HOLDER") - assert m.foo is "CARD_HOLDER" + assert m.foo == "CARD_HOLDER" m = Model.construct(foo={"bar": False}) assert isinstance(m.foo, Submodel1) @@ -323,7 +323,7 @@ class Model(BaseModel): assert len(m.items) == 2 assert isinstance(m.items[0], Submodel1) assert m.items[0].level == -1 - assert m.items[1] == 156 + assert cast(Any, m.items[1]) == 156 def test_union_of_lists() -> None: @@ -355,7 +355,7 @@ class Model(BaseModel): assert len(m.items) == 2 assert isinstance(m.items[0], SubModel1) assert m.items[0].level == -1 - assert m.items[1] == 156 + assert cast(Any, m.items[1]) == 156 def test_dict_of_union() -> None: @@ -520,19 +520,15 @@ class Model(BaseModel): assert m3.to_dict(exclude_none=True) == {} assert m3.to_dict(exclude_defaults=True) == {} - if PYDANTIC_V2: - - class Model2(BaseModel): - created_at: datetime + class Model2(BaseModel): + created_at: datetime - time_str = "2024-03-21T11:39:01.275859" - m4 = Model2.construct(created_at=time_str) - assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} - assert m4.to_dict(mode="json") == {"created_at": time_str} - else: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.to_dict(mode="json") + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + if not PYDANTIC_V2: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -558,9 +554,6 @@ class Model(BaseModel): assert m3.model_dump(exclude_none=True) == {} if not PYDANTIC_V2: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.model_dump(mode="json") - with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) diff --git a/tests/test_response.py b/tests/test_response.py index eebb8027..9c56eceb 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,5 +1,5 @@ import json -from typing import List, cast +from typing import Any, List, Union, cast from typing_extensions import Annotated import httpx @@ -19,16 +19,13 @@ from maisa._base_client import FinalRequestOptions -class ConcreteBaseAPIResponse(APIResponse[bytes]): - ... +class ConcreteBaseAPIResponse(APIResponse[bytes]): ... -class ConcreteAPIResponse(APIResponse[List[str]]): - ... +class ConcreteAPIResponse(APIResponse[List[str]]): ... -class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): - ... +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ... def test_extract_response_type_direct_classes() -> None: @@ -56,8 +53,7 @@ def test_extract_response_type_binary_response() -> None: assert extract_response_type(AsyncBinaryAPIResponse) == bytes -class PydanticModel(pydantic.BaseModel): - ... +class PydanticModel(pydantic.BaseModel): ... def test_response_parse_mismatched_basemodel(client: Maisa) -> None: @@ -192,3 +188,90 @@ async def test_async_response_parse_annotated_type(async_client: AsyncMaisa) -> ) assert obj.foo == "hello!" assert obj.bar == 2 + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +def test_response_parse_bool(client: Maisa, content: str, expected: bool) -> None: + response = APIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = response.parse(to=bool) + assert result is expected + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +async def test_async_response_parse_bool(client: AsyncMaisa, content: str, expected: bool) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = await response.parse(to=bool) + assert result is expected + + +class OtherModel(BaseModel): + a: str + + +@pytest.mark.parametrize("client", [False], indirect=True) # loose validation +def test_response_parse_expect_model_union_non_json_content(client: Maisa) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation +async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncMaisa) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" diff --git a/tests/test_transform.py b/tests/test_transform.py index 19442b4f..ce1fcc70 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False): foo: Annotated[date, PropertyInfo(format="iso8601")] +class DatetimeModel(BaseModel): + foo: datetime + + +class DateModel(BaseModel): + foo: Optional[date] + + @parametrize @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + tz = "Z" if PYDANTIC_V2 else "+00:00" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] dt = dt.replace(tzinfo=None) assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == { + "foo": "2023-02-23" + } # type: ignore[comparison-overlap] @parametrize @@ -260,20 +275,22 @@ class MyModel(BaseModel): @parametrize @pytest.mark.asyncio async def test_pydantic_model_to_dictionary(use_async: bool) -> None: - assert await transform(MyModel(foo="hi!"), Any, use_async) == {"foo": "hi!"} - assert await transform(MyModel.construct(foo="hi!"), Any, use_async) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel.construct(foo="hi!"), Any, use_async)) == {"foo": "hi!"} @parametrize @pytest.mark.asyncio async def test_pydantic_empty_model(use_async: bool) -> None: - assert await transform(MyModel.construct(), Any, use_async) == {} + assert cast(Any, await transform(MyModel.construct(), Any, use_async)) == {} @parametrize @pytest.mark.asyncio async def test_pydantic_unknown_field(use_async: bool) -> None: - assert await transform(MyModel.construct(my_untyped_field=True), Any, use_async) == {"my_untyped_field": True} + assert cast(Any, await transform(MyModel.construct(my_untyped_field=True), Any, use_async)) == { + "my_untyped_field": True + } @parametrize @@ -285,7 +302,7 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None: params = await transform(model, Any, use_async) else: params = await transform(model, Any, use_async) - assert params == {"foo": True} + assert cast(Any, params) == {"foo": True} @parametrize @@ -297,7 +314,7 @@ async def test_pydantic_mismatched_object_type(use_async: bool) -> None: params = await transform(model, Any, use_async) else: params = await transform(model, Any, use_async) - assert params == {"foo": {"hello": "world"}} + assert cast(Any, params) == {"foo": {"hello": "world"}} class ModelNestedObjects(BaseModel): @@ -309,7 +326,7 @@ class ModelNestedObjects(BaseModel): async def test_pydantic_nested_objects(use_async: bool) -> None: model = ModelNestedObjects.construct(nested={"foo": "stainless"}) assert isinstance(model.nested, MyModel) - assert await transform(model, Any, use_async) == {"nested": {"foo": "stainless"}} + assert cast(Any, await transform(model, Any, use_async)) == {"nested": {"foo": "stainless"}} class ModelWithDefaultField(BaseModel): @@ -325,19 +342,19 @@ async def test_pydantic_default_field(use_async: bool) -> None: model = ModelWithDefaultField.construct() assert model.with_none_default is None assert model.with_str_default == "foo" - assert await transform(model, Any, use_async) == {} + assert cast(Any, await transform(model, Any, use_async)) == {} # should be included when the default value is explicitly given model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") assert model.with_none_default is None assert model.with_str_default == "foo" - assert await transform(model, Any, use_async) == {"with_none_default": None, "with_str_default": "foo"} + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": None, "with_str_default": "foo"} # should be included when a non-default value is explicitly given model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") assert model.with_none_default == "bar" assert model.with_str_default == "baz" - assert await transform(model, Any, use_async) == {"with_none_default": "bar", "with_str_default": "baz"} + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": "bar", "with_str_default": "baz"} class TypedDictIterableUnion(TypedDict): diff --git a/tests/test_utils/test_typing.py b/tests/test_utils/test_typing.py index 0c5db425..f934fafc 100644 --- a/tests/test_utils/test_typing.py +++ b/tests/test_utils/test_typing.py @@ -9,24 +9,19 @@ _T3 = TypeVar("_T3") -class BaseGeneric(Generic[_T]): - ... +class BaseGeneric(Generic[_T]): ... -class SubclassGeneric(BaseGeneric[_T]): - ... +class SubclassGeneric(BaseGeneric[_T]): ... -class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): - ... +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): ... -class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): - ... +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): ... -class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): - ... +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): ... def test_extract_type_var() -> None: diff --git a/tests/utils.py b/tests/utils.py index 9fa6ba7a..88d3e2b5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type -from maisa._types import NoneType +from maisa._types import Omit, NoneType from maisa._utils import ( is_dict, is_list, @@ -139,11 +139,15 @@ def _assert_list_type(type_: type[object], value: object) -> None: @contextlib.contextmanager -def update_env(**new_env: str) -> Iterator[None]: +def update_env(**new_env: str | Omit) -> Iterator[None]: old = os.environ.copy() try: - os.environ.update(new_env) + for name, value in new_env.items(): + if isinstance(value, Omit): + os.environ.pop(name, None) + else: + os.environ[name] = value yield None finally: