From 17ecf5b9cf40200236749fdb439cf8f0ca58f909 Mon Sep 17 00:00:00 2001 From: Francesco Faraone Date: Fri, 29 Sep 2023 12:15:51 +0200 Subject: [PATCH] LITE-28571 cache transports by endpoint address --- connect/client/fluent.py | 65 +++++++++++++++++++++++++++---------- connect/client/mixins.py | 4 +-- tests/client/test_fluent.py | 2 +- 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/connect/client/fluent.py b/connect/client/fluent.py index ed626df..9aa22e4 100644 --- a/connect/client/fluent.py +++ b/connect/client/fluent.py @@ -10,6 +10,7 @@ import httpx import requests +from requests.adapters import HTTPAdapter from connect.client.constants import CONNECT_ENDPOINT_URL, CONNECT_SPECS_URL from connect.client.help_formatter import DefaultFormatter @@ -24,7 +25,11 @@ from connect.client.utils import get_headers -class _ConnectClientBase(threading.local): +_SYNC_TRANSPORTS = {} +_ASYNC_TRANSPORTS = {} + + +class _ConnectClientBase: def __init__( self, api_key, @@ -53,25 +58,11 @@ def __init__( self.specs = None if self._use_specs: self.specs = OpenAPISpecs(self.specs_location) - self._response = None self.logger = logger self._help_formatter = DefaultFormatter(self.specs) self.timeout = timeout self.resourceset_append = resourceset_append - @property - def response(self) -> requests.Response: - """ - Returns the raw - [`requests`](https://requests.readthedocs.io/en/latest/api/#requests.Response) - response. - """ - return self._response - - @response.setter - def response(self, value: requests.Response): - self._response = value - def __getattr__(self, name): if '_' in name: name = name.replace('_', '-') @@ -173,7 +164,7 @@ def _get_api_error_details(self): pass -class ConnectClient(_ConnectClientBase, threading.local, SyncClientMixin): +class ConnectClient(_ConnectClientBase, SyncClientMixin): """ Create a new instance of the ConnectClient. @@ -203,7 +194,33 @@ class ConnectClient(_ConnectClientBase, threading.local, SyncClientMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._session = requests.Session() + self._thread_locals = threading.local() + self._thread_locals.response = None + self._thread_locals.session = requests.Session() + self._thread_locals.session.mount( + self.endpoint, + _SYNC_TRANSPORTS.setdefault( + self.endpoint, + HTTPAdapter(), + ), + ) + + @property + def session(self): + return self._thread_locals.session + + @property + def response(self) -> requests.Response: + """ + Returns the raw + [`requests`](https://requests.readthedocs.io/en/latest/api/#requests.Response) + response. + """ + return self._thread_locals.response + + @response.setter + def response(self, value: requests.Response): + self._thread_locals.response = value def _get_collection_class(self): return Collection @@ -246,7 +263,19 @@ class AsyncConnectClient(_ConnectClientBase, AsyncClientMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._response = contextvars.ContextVar('response', default=None) - self._session = httpx.AsyncClient(verify=_SSL_CONTEXT) + self._session = contextvars.ContextVar( + 'session', + default=httpx.AsyncClient( + transport=_ASYNC_TRANSPORTS.setdefault( + self.endpoint, + httpx.AsyncHTTPTransport(verify=_SSL_CONTEXT), + ), + ), + ) + + @property + def session(self): + return self._session.get() @property def response(self): diff --git a/connect/client/mixins.py b/connect/client/mixins.py index 3efee58..7ca6182 100644 --- a/connect/client/mixins.py +++ b/connect/client/mixins.py @@ -99,7 +99,7 @@ def _execute_http_call(self, method, url, kwargs): # noqa: CCR001 if self.logger: self.logger.log_request(method, url, kwargs) try: - self.response = self._session.request(method, url, **kwargs) + self.response = self.session.request(method, url, **kwargs) if self.logger: self.logger.log_response(self.response) except RequestException: @@ -209,7 +209,7 @@ async def _execute_http_call(self, method, url, kwargs): self.logger.log_request(method, url, kwargs) try: - self.response = await self._session.request(method, url, **kwargs) + self.response = await self.session.request(method, url, **kwargs) if self.logger: self.logger.log_response(self.response) diff --git a/tests/client/test_fluent.py b/tests/client/test_fluent.py index 9b016f1..3167878 100644 --- a/tests/client/test_fluent.py +++ b/tests/client/test_fluent.py @@ -549,4 +549,4 @@ def test_sync_client_manage_response(): c = ConnectClient('API_KEY') assert c.response is None c.response = 'Some response' - assert c._response == 'Some response' + assert c._thread_locals.response == 'Some response'