diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 0a8d29eb41be4..131d71c5b57a3 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1618,6 +1618,15 @@ workers: type: float example: ~ default: "90.0" + execution_api_timeout: + description: | + The timeout (in seconds) for HTTP requests from workers to the Execution API server. + This controls how long a worker will wait for a response from the API server before + timing out. Increase this value if you experience timeout errors under high load. + version_added: 3.1.1 + type: float + example: ~ + default: "5.0" socket_cleanup_timeout: description: | Number of seconds to wait after a task process exits before forcibly closing any diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 105faaf21cddd..a6a12e9daca50 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -829,6 +829,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response: API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min") API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max") API_SSL_CERT_PATH = conf.get("api", "ssl_cert") +API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout") class Client(httpx.Client): @@ -848,6 +849,10 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * if API_SSL_CERT_PATH: ctx.load_verify_locations(API_SSL_CERT_PATH) kwargs["verify"] = ctx + + # Set timeout if not explicitly provided + kwargs.setdefault("timeout", API_TIMEOUT) + pyver = f"{'.'.join(map(str, sys.version_info[:3]))}" super().__init__( auth=auth, diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index cbd7c80bb0178..32a709663acb2 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -30,7 +30,7 @@ from uuid6 import uuid7 from airflow.sdk import timezone -from airflow.sdk.api.client import RemoteValidationError, ServerResponseError +from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, AssetResponse, @@ -99,6 +99,23 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert isinstance(err.value, FileNotFoundError) + @mock.patch("airflow.sdk.api.client.API_TIMEOUT", 60.0) + def test_timeout_configuration(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200) + + client = make_client(httpx.MockTransport(handle_request)) + assert client.timeout == httpx.Timeout(60.0) + + def test_timeout_can_be_overridden(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200) + + client = Client( + base_url="test://server", token="", transport=httpx.MockTransport(handle_request), timeout=120.0 + ) + assert client.timeout == httpx.Timeout(120.0) + def test_error_parsing(self): responses = [ httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})