From d356c0b701afbc1d4859967f16352d354044bfcc Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 00:57:36 +0000 Subject: [PATCH 01/29] feat: Update to support python 3.12 - add case statement to use the asyncio.Queue.shutdown method for 3.13+ - add special handling to allow for similar semantics as asyncio.Queue.shutdown for 3.12 Tested on multiple samples in the a2a repo and some examples in this repo --- pyproject.toml | 4 ++++ src/a2a/server/events/event_consumer.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 991fc8df4..9dcbd527a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,14 @@ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", +<<<<<<< HEAD "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", +======= + "Programming Language :: Python :: 3.12", +>>>>>>> 8ec734c (feat: Update to support python 3.12) "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index 518680695..2a96eca20 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -15,6 +15,12 @@ from a2a.utils.errors import ServerError from a2a.utils.telemetry import SpanKind, trace_class +# This is an alias to the execption for closed queue +QueueClosed = asyncio.QueueEmpty + +# When using python 3.13 or higher, the closed queue signal is QueueShutdown +if sys.version_info >= (3, 13): + QueueClosed = asyncio.QueueShutDown # This is an alias to the exception for closed queue QueueClosed = asyncio.QueueEmpty From fd0ec5af2e50e099994a611eb2ebb4d1534adb0f Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 15:08:38 +0000 Subject: [PATCH 02/29] Change to 3.10 and provided detailed description about event queue usage --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9dcbd527a..991fc8df4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,14 +22,10 @@ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", -<<<<<<< HEAD "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", -======= - "Programming Language :: Python :: 3.12", ->>>>>>> 8ec734c (feat: Update to support python 3.12) "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", From 6bd1d449628fab7abd88f63c576e4e74da1cd448 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 23:44:47 +0000 Subject: [PATCH 03/29] fix merge conflict --- src/a2a/server/events/event_consumer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index 2a96eca20..a5c31317d 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -15,14 +15,8 @@ from a2a.utils.errors import ServerError from a2a.utils.telemetry import SpanKind, trace_class -# This is an alias to the execption for closed queue -QueueClosed = asyncio.QueueEmpty -# When using python 3.13 or higher, the closed queue signal is QueueShutdown -if sys.version_info >= (3, 13): - QueueClosed = asyncio.QueueShutDown - -# This is an alias to the exception for closed queue +# This is an alias to the execption for closed queue QueueClosed = asyncio.QueueEmpty # When using python 3.13 or higher, the closed queue signal is QueueShutdown From b7964cfef461cd5ca97d760920b1d85cd31efebe Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 23:45:50 +0000 Subject: [PATCH 04/29] Fix typo --- src/a2a/server/events/event_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index a5c31317d..518680695 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -16,7 +16,7 @@ from a2a.utils.telemetry import SpanKind, trace_class -# This is an alias to the execption for closed queue +# This is an alias to the exception for closed queue QueueClosed = asyncio.QueueEmpty # When using python 3.13 or higher, the closed queue signal is QueueShutdown From 22a6230912609b442aa93c21a6a96ae6f9148b18 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Wed, 4 Jun 2025 20:33:31 +0000 Subject: [PATCH 05/29] Add gRPC based support in the SDK. - Introduces an A2AGrpcClient to talk to server over grpc - Introduces GrpcHandler to tranlate the gRPC transport to the internal python data model and back. - A set of transform operations in proto_utils.py to handle the transform This is a starting point and can be iterated and optimized as we move forward, especially trying to automate the transform code so it stays in sync. --- buf.gen.yaml | 5 +- src/a2a/client/__init__.py | 2 + src/a2a/client/grpc_client.py | 192 +++++ src/a2a/grpc/__init__.py | 0 src/a2a/grpc/a2a_pb2.py | 180 ++++ src/a2a/grpc/a2a_pb2.pyi | 520 ++++++++++++ src/a2a/grpc/a2a_pb2_grpc.py | 478 +++++++++++ src/a2a/server/request_handlers/__init__.py | 2 + .../default_request_handler.py | 3 + .../server/request_handlers/grpc_handler.py | 358 ++++++++ src/a2a/utils/helpers.py | 33 +- src/a2a/utils/proto_utils.py | 781 ++++++++++++++++++ 12 files changed, 2552 insertions(+), 2 deletions(-) create mode 100644 src/a2a/client/grpc_client.py create mode 100644 src/a2a/grpc/__init__.py create mode 100644 src/a2a/grpc/a2a_pb2.py create mode 100644 src/a2a/grpc/a2a_pb2.pyi create mode 100644 src/a2a/grpc/a2a_pb2_grpc.py create mode 100644 src/a2a/server/request_handlers/grpc_handler.py create mode 100644 src/a2a/utils/proto_utils.py diff --git a/buf.gen.yaml b/buf.gen.yaml index 7102471ef..e5e18e657 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -19,9 +19,12 @@ managed: plugins: # Generate python protobuf related code # Generates *_pb2.py files, one for each .proto - - remote: buf.build/protocolbuffers/python + - remote: buf.build/protocolbuffers/python:v29.3 out: src/a2a/grpc # Generate python service code. # Generates *_pb2_grpc.py - remote: buf.build/grpc/python out: src/a2a/grpc + # Generates *_pb2.pyi files. + - remote: buf.build/protocolbuffers/pyi:v29.3 + out: src/a2a/grpc diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 3455c8675..1a2bb5449 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -1,6 +1,7 @@ """Client-side components for interacting with an A2A agent.""" from a2a.client.client import A2ACardResolver, A2AClient +from a2a.client.grpc_client import A2AGrpcClient from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, @@ -15,5 +16,6 @@ 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', + 'A2AGrpcClient', 'create_text_message_object', ] diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py new file mode 100644 index 000000000..40fea5e42 --- /dev/null +++ b/src/a2a/client/grpc_client.py @@ -0,0 +1,192 @@ +import json +import logging +from collections.abc import AsyncGenerator +from typing import Any +from uuid import uuid4 +import grpc + +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.types import ( + AgentCard, + MessageSendParams, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + TaskPushNotificationConfig, + TaskIdParams, + TaskQueryParams, + Message, +) +from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils import proto_utils +from a2a.grpc import a2a_pb2_grpc +from a2a.grpc import a2a_pb2 + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class A2AGrpcClient: + """A2A Client for interacting with an A2A agent via gRPC.""" + + def __init__( + self, + grpc_stub: a2a_pb2_grpc.A2AServiceStub, + agent_card: AgentCard, + ): + """Initializes the A2AGrpcClient. + + Requires an `AgentCard` + + Args: + grpc_stub: A grpc client stub. + agent_card: The agent card object. + """ + self.agent_card = agent_card + self.stub = grpc_stub + + async def send_message( + self, + request: MessageSendParams, + ) -> Task | Message : + """Sends a non-streaming message request to the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Returns: + A `Task` or `Message` object containing the agent's response. + """ + response = await self.stub.SendMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + if response.task: + return proto_utils.FromProto.task(response.task) + return proto_utils.FromProto.message(response.msg) + + async def send_message_streaming( + self, + request: MessageSendParams, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses gRPC streams to receive a stream of updates from the + agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Yields: + `Message` or `Task` or `TaskStatusUpdateEvent` or + `TaskArtifactUpdateEvent` objects as they are received in the + stream. + """ + stream = self.stub.SendStreamingMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + if response.HasField('msg'): + yield proto_utils.FromProto.message(response.msg) + elif response.HasField('task'): + yield proto_utils.FromProto.task(response.task) + elif response.HasField('status_update'): + yield proto_utils.FromProto.task_status_update_event( + response.status_update + ) + elif response.HasField('artifact_update'): + yield proto_utils.FromProto.task_artifact_update_event( + response.artifact_update + ) + + async def get_task( + self, + request: TaskQueryParams, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID + + Returns: + A `Task` object containing the Task or None. + """ + task = await self.stub.GetTask( + a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `Task` object containing the updated Task + """ + task = await self.stub.CancelTask( + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + config = await self.stub.CreateTaskPushNotification( + a2a_pb2.CreateTaskPushNotificationRequest( + parent='', + config_id='', + config=proto_utils.ToProto.task_push_notification_config( + request + ), + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: TaskIdParams, # TODO: Update to a push id params + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + config = await self.stub.GetTaskPushNotification( + a2a_pb2.GetTaskPushNotificationRequest( + name=f'tasks/{request.id}/pushNotification/undefined', + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) diff --git a/src/a2a/grpc/__init__.py b/src/a2a/grpc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/a2a/grpc/a2a_pb2.py b/src/a2a/grpc/a2a_pb2.py new file mode 100644 index 000000000..81078b8be --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: a2a.proto +# Protobuf Python Version: 5.29.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 3, + '', + 'a2a.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 +from google.api import client_pb2 as google_dot_api_dot_client__pb2 +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x98\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12\'\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x06update\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"\xc8\x05\n\tAgentCard\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07request\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"4\n\x1eGetTaskPushNotificationRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa3\x01\n!CreateTaskPushNotificationRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"u\n\x1fListTaskPushNotificationRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"i\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msgB\t\n\x07payload\"\xf6\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msg\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x88\x01\n ListTaskPushNotificationResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xd0\x08\n\nA2AService\x12\x64\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1c\x82\xd3\xe4\x93\x02\x16\"\x11/v1//message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12W\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\" \x82\xd3\xe4\x93\x02\x1a\"\x15/v1/tasks/{id}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xb2\x01\n\x1a\x43reateTaskPushNotification\x12).a2a.v1.CreateTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"E\xda\x41\rparent,config\x82\xd3\xe4\x93\x02/\"%/v1/{parent=task/*/pushNotifications}:\x06\x63onfig\x12\x9c\x01\n\x17GetTaskPushNotification\x12&.a2a.v1.GetTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"5\xda\x41\x04name\x82\xd3\xe4\x93\x02(\x12&/v1/{name=tasks/*/pushNotifications/*}\x12\xa6\x01\n\x18ListTaskPushNotification\x12\'.a2a.v1.ListTaskPushNotificationRequest\x1a(.a2a.v1.ListTaskPushNotificationResponse\"7\xda\x41\x06parent\x82\xd3\xe4\x93\x02(\x12&/v1/{parent=tasks/*}/pushNotifications\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/cardBi\n\ncom.a2a.v1B\x08\x41\x32\x61ProtoP\x01Z\x18google.golang.org/a2a/v1\xa2\x02\x03\x41XX\xaa\x02\x06\x41\x32\x61.V1\xca\x02\x06\x41\x32\x61\\V1\xe2\x02\x12\x41\x32\x61\\V1\\GPBMetadata\xea\x02\x07\x41\x32\x61::V1b\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'a2a_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\ncom.a2a.v1B\010A2aProtoP\001Z\030google.golang.org/a2a/v1\242\002\003AXX\252\002\006A2a.V1\312\002\006A2a\\V1\342\002\022A2a\\V1\\GPBMetadata\352\002\007A2a::V1' + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._loaded_options = None + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_options = b'8\001' + _globals['_SECURITY_SCHEMESENTRY']._loaded_options = None + _globals['_SECURITY_SCHEMESENTRY']._serialized_options = b'8\001' + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._loaded_options = None + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._serialized_options = b'\340A\002' + _globals['_GETTASKREQUEST'].fields_by_name['name']._loaded_options = None + _globals['_GETTASKREQUEST'].fields_by_name['name']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._serialized_options = b'\340A\002' + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._serialized_options = b'\202\323\344\223\002\026\"\021/v1//message:send:\001*' + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._serialized_options = b'\202\323\344\223\002\027\"\022/v1/message:stream:\001*' + _globals['_A2ASERVICE'].methods_by_name['GetTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTask']._serialized_options = b'\332A\004name\202\323\344\223\002\024\022\022/v1/{name=tasks/*}' + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._serialized_options = b'\202\323\344\223\002\032\"\025/v1/tasks/{id}:cancel:\001*' + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._serialized_options = b'\202\323\344\223\002\036\022\034/v1/{name=tasks/*}:subscribe' + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._serialized_options = b'\332A\rparent,config\202\323\344\223\002/\"%/v1/{parent=task/*/pushNotifications}:\006config' + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._serialized_options = b'\332A\004name\202\323\344\223\002(\022&/v1/{name=tasks/*/pushNotifications/*}' + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._serialized_options = b'\332A\006parent\202\323\344\223\002(\022&/v1/{parent=tasks/*}/pushNotifications' + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._serialized_options = b'\202\323\344\223\002\n\022\010/v1/card' + _globals['_TASKSTATE']._serialized_start=7161 + _globals['_TASKSTATE']._serialized_end=7411 + _globals['_ROLE']._serialized_start=7413 + _globals['_ROLE']._serialized_end=7472 + _globals['_SENDMESSAGECONFIGURATION']._serialized_start=173 + _globals['_SENDMESSAGECONFIGURATION']._serialized_end=395 + _globals['_TASK']._serialized_start=398 + _globals['_TASK']._serialized_end=639 + _globals['_TASKSTATUS']._serialized_start=642 + _globals['_TASKSTATUS']._serialized_end=794 + _globals['_PART']._serialized_start=796 + _globals['_PART']._serialized_end=912 + _globals['_FILEPART']._serialized_start=914 + _globals['_FILEPART']._serialized_end=1041 + _globals['_DATAPART']._serialized_start=1043 + _globals['_DATAPART']._serialized_end=1098 + _globals['_MESSAGE']._serialized_start=1101 + _globals['_MESSAGE']._serialized_end=1356 + _globals['_ARTIFACT']._serialized_start=1359 + _globals['_ARTIFACT']._serialized_end=1577 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_start=1580 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_end=1778 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_start=1781 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_end=2016 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_start=2019 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_end=2167 + _globals['_AUTHENTICATIONINFO']._serialized_start=2169 + _globals['_AUTHENTICATIONINFO']._serialized_end=2249 + _globals['_AGENTCARD']._serialized_start=2252 + _globals['_AGENTCARD']._serialized_end=2964 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_start=2874 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_end=2964 + _globals['_AGENTPROVIDER']._serialized_start=2966 + _globals['_AGENTPROVIDER']._serialized_end=3035 + _globals['_AGENTCAPABILITIES']._serialized_start=3038 + _globals['_AGENTCAPABILITIES']._serialized_end=3190 + _globals['_AGENTEXTENSION']._serialized_start=3193 + _globals['_AGENTEXTENSION']._serialized_end=3338 + _globals['_AGENTSKILL']._serialized_start=3341 + _globals['_AGENTSKILL']._serialized_end=3539 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_start=3542 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_end=3680 + _globals['_STRINGLIST']._serialized_start=3682 + _globals['_STRINGLIST']._serialized_end=3714 + _globals['_SECURITY']._serialized_start=3717 + _globals['_SECURITY']._serialized_end=3864 + _globals['_SECURITY_SCHEMESENTRY']._serialized_start=3786 + _globals['_SECURITY_SCHEMESENTRY']._serialized_end=3864 + _globals['_SECURITYSCHEME']._serialized_start=3867 + _globals['_SECURITYSCHEME']._serialized_end=4268 + _globals['_APIKEYSECURITYSCHEME']._serialized_start=4270 + _globals['_APIKEYSECURITYSCHEME']._serialized_end=4374 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_start=4376 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_end=4495 + _globals['_OAUTH2SECURITYSCHEME']._serialized_start=4497 + _globals['_OAUTH2SECURITYSCHEME']._serialized_end=4595 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_start=4597 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_end=4707 + _globals['_OAUTHFLOWS']._serialized_start=4710 + _globals['_OAUTHFLOWS']._serialized_end=5014 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_start=5017 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_end=5283 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_start=5286 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_end=5507 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_IMPLICITOAUTHFLOW']._serialized_start=5510 + _globals['_IMPLICITOAUTHFLOW']._serialized_end=5729 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_PASSWORDOAUTHFLOW']._serialized_start=5732 + _globals['_PASSWORDOAUTHFLOW']._serialized_end=5935 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_SENDMESSAGEREQUEST']._serialized_start=5938 + _globals['_SENDMESSAGEREQUEST']._serialized_end=6131 + _globals['_GETTASKREQUEST']._serialized_start=6133 + _globals['_GETTASKREQUEST']._serialized_end=6213 + _globals['_CANCELTASKREQUEST']._serialized_start=6215 + _globals['_CANCELTASKREQUEST']._serialized_end=6254 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6256 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6308 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_start=6311 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_end=6474 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_start=6476 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_end=6521 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6523 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6640 + _globals['_GETAGENTCARDREQUEST']._serialized_start=6642 + _globals['_GETAGENTCARDREQUEST']._serialized_end=6663 + _globals['_SENDMESSAGERESPONSE']._serialized_start=6665 + _globals['_SENDMESSAGERESPONSE']._serialized_end=6770 + _globals['_STREAMRESPONSE']._serialized_start=6773 + _globals['_STREAMRESPONSE']._serialized_end=7019 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_start=7022 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_end=7158 + _globals['_A2ASERVICE']._serialized_start=7475 + _globals['_A2ASERVICE']._serialized_end=8579 +# @@protoc_insertion_point(module_scope) diff --git a/src/a2a/grpc/a2a_pb2.pyi b/src/a2a/grpc/a2a_pb2.pyi new file mode 100644 index 000000000..8d2fad9b8 --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.pyi @@ -0,0 +1,520 @@ +from google.api import annotations_pb2 as _annotations_pb2 +from google.api import client_pb2 as _client_pb2 +from google.api import field_behavior_pb2 as _field_behavior_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class TaskState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + TASK_STATE_UNSPECIFIED: _ClassVar[TaskState] + TASK_STATE_SUBMITTED: _ClassVar[TaskState] + TASK_STATE_WORKING: _ClassVar[TaskState] + TASK_STATE_COMPLETED: _ClassVar[TaskState] + TASK_STATE_FAILED: _ClassVar[TaskState] + TASK_STATE_CANCELLED: _ClassVar[TaskState] + TASK_STATE_INPUT_REQUIRED: _ClassVar[TaskState] + TASK_STATE_REJECTED: _ClassVar[TaskState] + TASK_STATE_AUTH_REQUIRED: _ClassVar[TaskState] + +class Role(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + ROLE_UNSPECIFIED: _ClassVar[Role] + ROLE_USER: _ClassVar[Role] + ROLE_AGENT: _ClassVar[Role] +TASK_STATE_UNSPECIFIED: TaskState +TASK_STATE_SUBMITTED: TaskState +TASK_STATE_WORKING: TaskState +TASK_STATE_COMPLETED: TaskState +TASK_STATE_FAILED: TaskState +TASK_STATE_CANCELLED: TaskState +TASK_STATE_INPUT_REQUIRED: TaskState +TASK_STATE_REJECTED: TaskState +TASK_STATE_AUTH_REQUIRED: TaskState +ROLE_UNSPECIFIED: Role +ROLE_USER: Role +ROLE_AGENT: Role + +class SendMessageConfiguration(_message.Message): + __slots__ = ("accepted_output_modes", "push_notification", "history_length", "blocking") + ACCEPTED_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + BLOCKING_FIELD_NUMBER: _ClassVar[int] + accepted_output_modes: _containers.RepeatedScalarFieldContainer[str] + push_notification: PushNotificationConfig + history_length: int + blocking: bool + def __init__(self, accepted_output_modes: _Optional[_Iterable[str]] = ..., push_notification: _Optional[_Union[PushNotificationConfig, _Mapping]] = ..., history_length: _Optional[int] = ..., blocking: bool = ...) -> None: ... + +class Task(_message.Message): + __slots__ = ("id", "context_id", "status", "artifacts", "history", "metadata") + ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + ARTIFACTS_FIELD_NUMBER: _ClassVar[int] + HISTORY_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + context_id: str + status: TaskStatus + artifacts: _containers.RepeatedCompositeFieldContainer[Artifact] + history: _containers.RepeatedCompositeFieldContainer[Message] + metadata: _struct_pb2.Struct + def __init__(self, id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., artifacts: _Optional[_Iterable[_Union[Artifact, _Mapping]]] = ..., history: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskStatus(_message.Message): + __slots__ = ("state", "update", "timestamp") + STATE_FIELD_NUMBER: _ClassVar[int] + UPDATE_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + state: TaskState + update: Message + timestamp: _timestamp_pb2.Timestamp + def __init__(self, state: _Optional[_Union[TaskState, str]] = ..., update: _Optional[_Union[Message, _Mapping]] = ..., timestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class Part(_message.Message): + __slots__ = ("text", "file", "data") + TEXT_FIELD_NUMBER: _ClassVar[int] + FILE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + text: str + file: FilePart + data: DataPart + def __init__(self, text: _Optional[str] = ..., file: _Optional[_Union[FilePart, _Mapping]] = ..., data: _Optional[_Union[DataPart, _Mapping]] = ...) -> None: ... + +class FilePart(_message.Message): + __slots__ = ("file_with_uri", "file_with_bytes", "mime_type") + FILE_WITH_URI_FIELD_NUMBER: _ClassVar[int] + FILE_WITH_BYTES_FIELD_NUMBER: _ClassVar[int] + MIME_TYPE_FIELD_NUMBER: _ClassVar[int] + file_with_uri: str + file_with_bytes: bytes + mime_type: str + def __init__(self, file_with_uri: _Optional[str] = ..., file_with_bytes: _Optional[bytes] = ..., mime_type: _Optional[str] = ...) -> None: ... + +class DataPart(_message.Message): + __slots__ = ("data",) + DATA_FIELD_NUMBER: _ClassVar[int] + data: _struct_pb2.Struct + def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class Message(_message.Message): + __slots__ = ("message_id", "context_id", "task_id", "role", "content", "metadata", "extensions") + MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + ROLE_FIELD_NUMBER: _ClassVar[int] + CONTENT_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + message_id: str + context_id: str + task_id: str + role: Role + content: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, message_id: _Optional[str] = ..., context_id: _Optional[str] = ..., task_id: _Optional[str] = ..., role: _Optional[_Union[Role, str]] = ..., content: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class Artifact(_message.Message): + __slots__ = ("artifact_id", "name", "description", "parts", "metadata", "extensions") + ARTIFACT_ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + PARTS_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + artifact_id: str + name: str + description: str + parts: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, artifact_id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., parts: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskStatusUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "status", "final", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + FINAL_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + status: TaskStatus + final: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., final: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskArtifactUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "artifact", "append", "last_chunk", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_FIELD_NUMBER: _ClassVar[int] + APPEND_FIELD_NUMBER: _ClassVar[int] + LAST_CHUNK_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + artifact: Artifact + append: bool + last_chunk: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., artifact: _Optional[_Union[Artifact, _Mapping]] = ..., append: bool = ..., last_chunk: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class PushNotificationConfig(_message.Message): + __slots__ = ("id", "url", "token", "authentication") + ID_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_FIELD_NUMBER: _ClassVar[int] + AUTHENTICATION_FIELD_NUMBER: _ClassVar[int] + id: str + url: str + token: str + authentication: AuthenticationInfo + def __init__(self, id: _Optional[str] = ..., url: _Optional[str] = ..., token: _Optional[str] = ..., authentication: _Optional[_Union[AuthenticationInfo, _Mapping]] = ...) -> None: ... + +class AuthenticationInfo(_message.Message): + __slots__ = ("schemes", "credentials") + SCHEMES_FIELD_NUMBER: _ClassVar[int] + CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.RepeatedScalarFieldContainer[str] + credentials: str + def __init__(self, schemes: _Optional[_Iterable[str]] = ..., credentials: _Optional[str] = ...) -> None: ... + +class AgentCard(_message.Message): + __slots__ = ("name", "description", "url", "provider", "version", "documentation_url", "capabilities", "security_schemes", "security", "default_input_modes", "default_output_modes", "skills", "supports_authenticated_extended_card") + class SecuritySchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: SecurityScheme + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[SecurityScheme, _Mapping]] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + PROVIDER_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + DOCUMENTATION_URL_FIELD_NUMBER: _ClassVar[int] + CAPABILITIES_FIELD_NUMBER: _ClassVar[int] + SECURITY_SCHEMES_FIELD_NUMBER: _ClassVar[int] + SECURITY_FIELD_NUMBER: _ClassVar[int] + DEFAULT_INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + DEFAULT_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + SKILLS_FIELD_NUMBER: _ClassVar[int] + SUPPORTS_AUTHENTICATED_EXTENDED_CARD_FIELD_NUMBER: _ClassVar[int] + name: str + description: str + url: str + provider: AgentProvider + version: str + documentation_url: str + capabilities: AgentCapabilities + security_schemes: _containers.MessageMap[str, SecurityScheme] + security: _containers.RepeatedCompositeFieldContainer[Security] + default_input_modes: _containers.RepeatedScalarFieldContainer[str] + default_output_modes: _containers.RepeatedScalarFieldContainer[str] + skills: _containers.RepeatedCompositeFieldContainer[AgentSkill] + supports_authenticated_extended_card: bool + def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., url: _Optional[str] = ..., provider: _Optional[_Union[AgentProvider, _Mapping]] = ..., version: _Optional[str] = ..., documentation_url: _Optional[str] = ..., capabilities: _Optional[_Union[AgentCapabilities, _Mapping]] = ..., security_schemes: _Optional[_Mapping[str, SecurityScheme]] = ..., security: _Optional[_Iterable[_Union[Security, _Mapping]]] = ..., default_input_modes: _Optional[_Iterable[str]] = ..., default_output_modes: _Optional[_Iterable[str]] = ..., skills: _Optional[_Iterable[_Union[AgentSkill, _Mapping]]] = ..., supports_authenticated_extended_card: bool = ...) -> None: ... + +class AgentProvider(_message.Message): + __slots__ = ("url", "organization") + URL_FIELD_NUMBER: _ClassVar[int] + ORGANIZATION_FIELD_NUMBER: _ClassVar[int] + url: str + organization: str + def __init__(self, url: _Optional[str] = ..., organization: _Optional[str] = ...) -> None: ... + +class AgentCapabilities(_message.Message): + __slots__ = ("streaming", "push_notifications", "extensions") + STREAMING_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATIONS_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + streaming: bool + push_notifications: bool + extensions: _containers.RepeatedCompositeFieldContainer[AgentExtension] + def __init__(self, streaming: bool = ..., push_notifications: bool = ..., extensions: _Optional[_Iterable[_Union[AgentExtension, _Mapping]]] = ...) -> None: ... + +class AgentExtension(_message.Message): + __slots__ = ("uri", "description", "required", "params") + URI_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + REQUIRED_FIELD_NUMBER: _ClassVar[int] + PARAMS_FIELD_NUMBER: _ClassVar[int] + uri: str + description: str + required: bool + params: _struct_pb2.Struct + def __init__(self, uri: _Optional[str] = ..., description: _Optional[str] = ..., required: bool = ..., params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class AgentSkill(_message.Message): + __slots__ = ("id", "name", "description", "tags", "examples", "input_modes", "output_modes") + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + TAGS_FIELD_NUMBER: _ClassVar[int] + EXAMPLES_FIELD_NUMBER: _ClassVar[int] + INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + description: str + tags: _containers.RepeatedScalarFieldContainer[str] + examples: _containers.RepeatedScalarFieldContainer[str] + input_modes: _containers.RepeatedScalarFieldContainer[str] + output_modes: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., tags: _Optional[_Iterable[str]] = ..., examples: _Optional[_Iterable[str]] = ..., input_modes: _Optional[_Iterable[str]] = ..., output_modes: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskPushNotificationConfig(_message.Message): + __slots__ = ("name", "push_notification_config") + NAME_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_CONFIG_FIELD_NUMBER: _ClassVar[int] + name: str + push_notification_config: PushNotificationConfig + def __init__(self, name: _Optional[str] = ..., push_notification_config: _Optional[_Union[PushNotificationConfig, _Mapping]] = ...) -> None: ... + +class StringList(_message.Message): + __slots__ = ("list",) + LIST_FIELD_NUMBER: _ClassVar[int] + list: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, list: _Optional[_Iterable[str]] = ...) -> None: ... + +class Security(_message.Message): + __slots__ = ("schemes",) + class SchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: StringList + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[StringList, _Mapping]] = ...) -> None: ... + SCHEMES_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.MessageMap[str, StringList] + def __init__(self, schemes: _Optional[_Mapping[str, StringList]] = ...) -> None: ... + +class SecurityScheme(_message.Message): + __slots__ = ("api_key_security_scheme", "http_auth_security_scheme", "oauth2_security_scheme", "open_id_connect_security_scheme") + API_KEY_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + HTTP_AUTH_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OAUTH2_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + api_key_security_scheme: APIKeySecurityScheme + http_auth_security_scheme: HTTPAuthSecurityScheme + oauth2_security_scheme: OAuth2SecurityScheme + open_id_connect_security_scheme: OpenIdConnectSecurityScheme + def __init__(self, api_key_security_scheme: _Optional[_Union[APIKeySecurityScheme, _Mapping]] = ..., http_auth_security_scheme: _Optional[_Union[HTTPAuthSecurityScheme, _Mapping]] = ..., oauth2_security_scheme: _Optional[_Union[OAuth2SecurityScheme, _Mapping]] = ..., open_id_connect_security_scheme: _Optional[_Union[OpenIdConnectSecurityScheme, _Mapping]] = ...) -> None: ... + +class APIKeySecurityScheme(_message.Message): + __slots__ = ("description", "location", "name") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + LOCATION_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + description: str + location: str + name: str + def __init__(self, description: _Optional[str] = ..., location: _Optional[str] = ..., name: _Optional[str] = ...) -> None: ... + +class HTTPAuthSecurityScheme(_message.Message): + __slots__ = ("description", "scheme", "bearer_format") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + SCHEME_FIELD_NUMBER: _ClassVar[int] + BEARER_FORMAT_FIELD_NUMBER: _ClassVar[int] + description: str + scheme: str + bearer_format: str + def __init__(self, description: _Optional[str] = ..., scheme: _Optional[str] = ..., bearer_format: _Optional[str] = ...) -> None: ... + +class OAuth2SecurityScheme(_message.Message): + __slots__ = ("description", "flows") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + FLOWS_FIELD_NUMBER: _ClassVar[int] + description: str + flows: OAuthFlows + def __init__(self, description: _Optional[str] = ..., flows: _Optional[_Union[OAuthFlows, _Mapping]] = ...) -> None: ... + +class OpenIdConnectSecurityScheme(_message.Message): + __slots__ = ("description", "open_id_connect_url") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_URL_FIELD_NUMBER: _ClassVar[int] + description: str + open_id_connect_url: str + def __init__(self, description: _Optional[str] = ..., open_id_connect_url: _Optional[str] = ...) -> None: ... + +class OAuthFlows(_message.Message): + __slots__ = ("authorization_code", "client_credentials", "implicit", "password") + AUTHORIZATION_CODE_FIELD_NUMBER: _ClassVar[int] + CLIENT_CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + IMPLICIT_FIELD_NUMBER: _ClassVar[int] + PASSWORD_FIELD_NUMBER: _ClassVar[int] + authorization_code: AuthorizationCodeOAuthFlow + client_credentials: ClientCredentialsOAuthFlow + implicit: ImplicitOAuthFlow + password: PasswordOAuthFlow + def __init__(self, authorization_code: _Optional[_Union[AuthorizationCodeOAuthFlow, _Mapping]] = ..., client_credentials: _Optional[_Union[ClientCredentialsOAuthFlow, _Mapping]] = ..., implicit: _Optional[_Union[ImplicitOAuthFlow, _Mapping]] = ..., password: _Optional[_Union[PasswordOAuthFlow, _Mapping]] = ...) -> None: ... + +class AuthorizationCodeOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ClientCredentialsOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ImplicitOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class PasswordOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class SendMessageRequest(_message.Message): + __slots__ = ("request", "configuration", "metadata") + REQUEST_FIELD_NUMBER: _ClassVar[int] + CONFIGURATION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + request: Message + configuration: SendMessageConfiguration + metadata: _struct_pb2.Struct + def __init__(self, request: _Optional[_Union[Message, _Mapping]] = ..., configuration: _Optional[_Union[SendMessageConfiguration, _Mapping]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class GetTaskRequest(_message.Message): + __slots__ = ("name", "history_length") + NAME_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + name: str + history_length: int + def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ...) -> None: ... + +class CancelTaskRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTaskPushNotificationRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class CreateTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "config_id", "config") + PARENT_FIELD_NUMBER: _ClassVar[int] + CONFIG_ID_FIELD_NUMBER: _ClassVar[int] + CONFIG_FIELD_NUMBER: _ClassVar[int] + parent: str + config_id: str + config: TaskPushNotificationConfig + def __init__(self, parent: _Optional[str] = ..., config_id: _Optional[str] = ..., config: _Optional[_Union[TaskPushNotificationConfig, _Mapping]] = ...) -> None: ... + +class TaskSubscriptionRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class ListTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "page_size", "page_token") + PARENT_FIELD_NUMBER: _ClassVar[int] + PAGE_SIZE_FIELD_NUMBER: _ClassVar[int] + PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + parent: str + page_size: int + page_token: str + def __init__(self, parent: _Optional[str] = ..., page_size: _Optional[int] = ..., page_token: _Optional[str] = ...) -> None: ... + +class GetAgentCardRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SendMessageResponse(_message.Message): + __slots__ = ("task", "msg") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ...) -> None: ... + +class StreamResponse(_message.Message): + __slots__ = ("task", "msg", "status_update", "artifact_update") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + STATUS_UPDATE_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_UPDATE_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + status_update: TaskStatusUpdateEvent + artifact_update: TaskArtifactUpdateEvent + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ..., status_update: _Optional[_Union[TaskStatusUpdateEvent, _Mapping]] = ..., artifact_update: _Optional[_Union[TaskArtifactUpdateEvent, _Mapping]] = ...) -> None: ... + +class ListTaskPushNotificationResponse(_message.Message): + __slots__ = ("configs", "next_page_token") + CONFIGS_FIELD_NUMBER: _ClassVar[int] + NEXT_PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + configs: _containers.RepeatedCompositeFieldContainer[TaskPushNotificationConfig] + next_page_token: str + def __init__(self, configs: _Optional[_Iterable[_Union[TaskPushNotificationConfig, _Mapping]]] = ..., next_page_token: _Optional[str] = ...) -> None: ... diff --git a/src/a2a/grpc/a2a_pb2_grpc.py b/src/a2a/grpc/a2a_pb2_grpc.py new file mode 100644 index 000000000..01a283739 --- /dev/null +++ b/src/a2a/grpc/a2a_pb2_grpc.py @@ -0,0 +1,478 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import a2a_pb2 as a2a__pb2 + + +class A2AServiceStub(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendMessage = channel.unary_unary( + '/a2a.v1.A2AService/SendMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.SendMessageResponse.FromString, + _registered_method=True) + self.SendStreamingMessage = channel.unary_stream( + '/a2a.v1.A2AService/SendStreamingMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.GetTask = channel.unary_unary( + '/a2a.v1.A2AService/GetTask', + request_serializer=a2a__pb2.GetTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.CancelTask = channel.unary_unary( + '/a2a.v1.A2AService/CancelTask', + request_serializer=a2a__pb2.CancelTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.TaskSubscription = channel.unary_stream( + '/a2a.v1.A2AService/TaskSubscription', + request_serializer=a2a__pb2.TaskSubscriptionRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.CreateTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/CreateTaskPushNotification', + request_serializer=a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.GetTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/GetTaskPushNotification', + request_serializer=a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.ListTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/ListTaskPushNotification', + request_serializer=a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.ListTaskPushNotificationResponse.FromString, + _registered_method=True) + self.GetAgentCard = channel.unary_unary( + '/a2a.v1.A2AService/GetAgentCard', + request_serializer=a2a__pb2.GetAgentCardRequest.SerializeToString, + response_deserializer=a2a__pb2.AgentCard.FromString, + _registered_method=True) + + +class A2AServiceServicer(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def SendMessage(self, request, context): + """Send a message to the agent. This is a blocking call that will return the + task once it is completed, or a LRO if requested. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendStreamingMessage(self, request, context): + """SendStreamingMessage is a streaming call that will return a stream of + task update events until the Task is in an interrupted or terminal state. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTask(self, request, context): + """Get the current state of a task from the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CancelTask(self, request, context): + """Cancel a task from the agent. If supported one should expect no + more task updates for the task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TaskSubscription(self, request, context): + """TaskSubscription is a streaming call that will return a stream of task + update events. This attaches the stream to an existing in process task. + If the task is complete the stream will return the completed task (like + GetTask) and close the stream. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateTaskPushNotification(self, request, context): + """Set a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTaskPushNotification(self, request, context): + """Get a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListTaskPushNotification(self, request, context): + """Get a list of push notifications configured for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetAgentCard(self, request, context): + """GetAgentCard returns the agent card for the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_A2AServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendMessage': grpc.unary_unary_rpc_method_handler( + servicer.SendMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.SendMessageResponse.SerializeToString, + ), + 'SendStreamingMessage': grpc.unary_stream_rpc_method_handler( + servicer.SendStreamingMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'GetTask': grpc.unary_unary_rpc_method_handler( + servicer.GetTask, + request_deserializer=a2a__pb2.GetTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'CancelTask': grpc.unary_unary_rpc_method_handler( + servicer.CancelTask, + request_deserializer=a2a__pb2.CancelTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'TaskSubscription': grpc.unary_stream_rpc_method_handler( + servicer.TaskSubscription, + request_deserializer=a2a__pb2.TaskSubscriptionRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'CreateTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.CreateTaskPushNotification, + request_deserializer=a2a__pb2.CreateTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'GetTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.GetTaskPushNotification, + request_deserializer=a2a__pb2.GetTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'ListTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.ListTaskPushNotification, + request_deserializer=a2a__pb2.ListTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.ListTaskPushNotificationResponse.SerializeToString, + ), + 'GetAgentCard': grpc.unary_unary_rpc_method_handler( + servicer.GetAgentCard, + request_deserializer=a2a__pb2.GetAgentCardRequest.FromString, + response_serializer=a2a__pb2.AgentCard.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'a2a.v1.A2AService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('a2a.v1.A2AService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class A2AService(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + @staticmethod + def SendMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/SendMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.SendMessageResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendStreamingMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/SendStreamingMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTask', + a2a__pb2.GetTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CancelTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CancelTask', + a2a__pb2.CancelTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def TaskSubscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/TaskSubscription', + a2a__pb2.TaskSubscriptionRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CreateTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CreateTaskPushNotification', + a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTaskPushNotification', + a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ListTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/ListTaskPushNotification', + a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + a2a__pb2.ListTaskPushNotificationResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetAgentCard(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetAgentCard', + a2a__pb2.GetAgentCardRequest.SerializeToString, + a2a__pb2.AgentCard.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index f0d2667d8..623843848 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -4,6 +4,7 @@ DefaultRequestHandler, ) from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler +from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, @@ -14,6 +15,7 @@ __all__ = [ 'DefaultRequestHandler', 'JSONRPCHandler', + 'GrpcHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 09b1d3049..660ef7ef2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -3,6 +3,7 @@ from collections.abc import AsyncGenerator from typing import cast +import uuid from a2a.server.agent_execution import ( AgentExecutor, @@ -364,6 +365,8 @@ async def on_set_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) + # Generate a unique id for the notification + params.pushNotificationConfig.id = str(uuid.uuid4()) await self._push_notifier.set_info( params.taskId, params.pushNotificationConfig, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py new file mode 100644 index 000000000..66c24804b --- /dev/null +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -0,0 +1,358 @@ +import logging +import grpc +import contextlib + +from typing import AsyncIterable +from abc import ABC, abstractmethod + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + InternalError, + Message, + Task, + TaskArtifactUpdateEvent, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, +) +from a2a import types +from a2a.auth.user import User as A2AUser +from a2a.auth.user import UnauthenticatedUser +from a2a.server.context import ServerCallContext +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate, validate_async_generator +from a2a.utils import proto_utils +import a2a.grpc.a2a_pb2 as a2a_pb2 +import a2a.grpc.a2a_pb2_grpc as a2a_grpc + + +logger = logging.getLogger(__name__) + +# For now we use a trivial wrapper on the grpc context object + +class CallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext from a gRPC Request.""" + + +class DefaultCallContextBuilder(CallContextBuilder): + """A default implementation of CallContextBuilder.""" + + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + user = UnauthenticatedUser() + state = {} + with contextlib.suppress(Exception): + state['grpc_context'] = context + return ServerCallContext(user=user, state=state) + + +class GrpcHandler(a2a_grpc.A2AServiceServicer): + """Maps incoming gRPC requests to the appropriate request handler method + and formats responses.""" + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + context_builder: CallContextBuilder = DefaultCallContextBuilder(), + ): + """Initializes the GrpcHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegat +e requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + self.context_builder = context_builder + + async def SendMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.SendMessageResponse: + """Handles the 'SendMessage' gRPC method. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Returns: + A `SendMessageResponse` object containing the result (Task or Messag +e) + or throws an error response if a `ServerError` is raised by the han +dler. + """ + try: + # Construct the server context object + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, server_context + ) + return proto_utils.ToProto.task_or_message(task_or_message) + except ServerError as e: + await self.abort_context(e, context) + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def SendStreamingMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'StreamMessage' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + or gRPC error responses if a `ServerError` is raised. + """ + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + try: + async for event in self.request_handler.on_message_send_stream( + a2a_request, server_context + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + return + + async def CancelTask( + self, + request: a2a_pb2.CancelTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'CancelTask' gRPC method. + + Args: + request: The incoming `CancelTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the updated Task or a gRPC error. + """ + try: + server_context = self.context_builder.build(context) + task_id_params = proto_utils.FromProto.task_id_params(request) + task = await self.request_handler.on_cancel_task( + task_id_params, server_context + ) + if task: + return proto_utils.ToProto.task(task) + self.abort_context(ServerError(error=TaskNotFoundError()), context) + except ServerError as e: + await self.abort_context(e, context) + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def TaskSubscription( + self, + request: a2a_pb2.TaskSubscriptionRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'TaskSubscription' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `TaskSubscriptionRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + """ + try: + server_context = self.context_builder.build(context) + async for event in self.request_handler.on_resubscribe_to_task( + proto_utils.FromProto.task_id_params(request), server_context, + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + + async def GetTaskPushNotification( + self, + request: a2a_pb2.GetTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Handles the 'GetTaskPushNotification' gRPC method. + + Args: + request: The incoming `GetTaskPushNotificationConfigRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_get_task_push_notification_config( + proto_utils.FromProto.task_id_params(request), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def CreateTaskPushNotification( + self, + request: a2a_pb2.CreateTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> TaskPushNotificationConfig: + """Handles the 'CreateTaskPushNotification' gRPC method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `CreateTaskPushNotificationRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator). + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_set_task_push_notification_config( + proto_utils.FromProto.task_push_notification_config( + request, + ), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + + async def GetTask( + self, + request: a2a_pb2.GetTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'GetTask' gRPC method. + + Args: + request: The incoming `GetTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object. + """ + try: + server_context = self.context_builder.build(context) + task = await self.request_handler.on_get_task( + proto_utils.FromProto.task_query_params(request), server_context + ) + if task: + return proto_utils.ToProto.task(task) + self.abort_context(ServerError(error=TaskNotFoundError()), context) + except ServerError as e: + await self.abort_context(e, context) + + async def GetAgentCard( + self, + request: a2a_pb2.GetAgentCardRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.AgentCard: + return proto_utils.ToProto.agent_card(self.agent_card) + + async def abort_context( + self, error: ServerError, context: grpc.ServicerContext + ): + match error.error: + case types.JSONParseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'JSONParseError: {error.error.message}', + ) + case types.InvalidRequestError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidRequestError: {error.error.message}', + ) + case types.MethodNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'MethodNotFoundError: {error.error.message}', + ) + case types.InvalidParamsError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidParamsError: {error.error.message}', + ) + case types.InternalError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InternalError: {error.error.message}', + ) + case types.TaskNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'TaskNotFoundError: {error.error.message}', + ) + case types.TaskNotCancelableError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'TaskNotCancelableError: {error.error.message}', + ) + case types.PushNotificationNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'PushNotificationNotSupportedError: {error.error.message}', + ) + case types.UnsupportedOperationError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'UnsupportedOperationError: {error.error.message}', + ) + case types.ContentTypeNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'ContentTypeNotSupportedError: {error.error.message}', + ) + case types.InvalidAgentResponseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InvalidAgentResponseError: {error.error.message}', + ) + case _: + await context.abort( + grpc.StatusCode.UNKNOWN, + f'Unknown error type: {error.error}', + ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 243ac87b0..4260aa6e1 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,5 +1,5 @@ """General utility functions for the A2A Python SDK.""" - +import functools import logging from collections.abc import Callable @@ -147,6 +147,37 @@ def wrapper(self, *args, **kwargs): return decorator +def validate_async_generator( + expression: Callable[[Any], bool], error_message: str | None = None +): + """Decorator that validates if a given expression evaluates to True. + + Typically used on class methods to check capabilities or configuration + before executing the method's logic. If the expression is False, + a `ServerError` with an `UnsupportedOperationError` is raised. + + Args: + expression: A callable that takes the instance (`self`) as its argument + and returns a boolean. + error_message: An optional custom error message for the `UnsupportedOperationError`. + If None, the string representation of the expression will be used. + """ + + def decorator(function): + @functools.wraps(function) + async def wrapper(self, *args, **kwargs): + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + async for i in function(self, *args, **kwargs): + yield i + + return wrapper + + return decorator def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py new file mode 100644 index 000000000..bc78abbf8 --- /dev/null +++ b/src/a2a/utils/proto_utils.py @@ -0,0 +1,781 @@ +"""Utils for converting between proto and Python types.""" + +import json +from typing import Any, Dict +import re + +from a2a.grpc import a2a_pb2 +from a2a import types +from a2a.utils.errors import ServerError +from google.protobuf import struct_pb2 +from google.protobuf import json_format + + +# Regexp patterns for matching +_TASK_NAME_MATCH = r'tasks/(\w+)' +_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)' + + +class ToProto: + """Converts Python types to proto types.""" + + @classmethod + def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: + if message is None: + return None + return a2a_pb2.Message( + message_id=message.messageId, + content=[ToProto.part(p) for p in message.parts], + context_id=message.contextId, + task_id=message.taskId, + role=cls.role(message.role.name), + metadata=ToProto.metadata(message.metadata), + ) + + @classmethod + def metadata( + cls, metadata: Dict[str, Any] | None + ) -> struct_pb2.Struct | None: + if metadata is None: + return None + return struct_pb2.Struct( + # TODO: Add support for other types. + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in metadata.items() + if isinstance(value, str) + } + ) + + @classmethod + def part(cls, part: types.Part) -> a2a_pb2.Part: + if isinstance(part.root, types.TextPart): + return a2a_pb2.Part(text=part.root.text) + elif isinstance(part.root, types.FilePart): + return a2a_pb2.Part(file=ToProto.file(part.root.file)) + elif isinstance(part.root, types.DataPart): + return a2a_pb2.Part(data=ToProto.data(part.root.data)) + else: + raise ValueError(f'Unsupported part type: {part.root}') + + @classmethod + def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: + json_data = json.dumps(data) + return a2a_pb2.DataPart( + data=json_format.Parse( + json_data, + struct_pb2.Struct(), + ) + ) + + @classmethod + def file( + cls, file: types.FileWithUri | types.FileWithBytes + ) -> a2a_pb2.FilePart: + if isinstance(file, types.FileWithUri): + return a2a_pb2.FilePart(file_with_uri=file.uri) + return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) + + @classmethod + def task(cls, task: types.Task) -> a2a_pb2.Task: + return a2a_pb2.Task( + id=task.id, + context_id=task.contextId, + status=ToProto.task_status(task.status), + artifacts=([ + ToProto.artifact(a) for a in task.artifacts + ] if task.artifacts else None), + history=([ + ToProto.message(h) for h in task.history + ] if task.history else None), + ) + + @classmethod + def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: + return a2a_pb2.TaskStatus( + state=ToProto.task_state(status.state), + update=ToProto.message(status.message), + ) + + @classmethod + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: + match state: + case types.TaskState.submitted: + return a2a_pb2.TaskState.TASK_STATE_SUBMITTED + case types.TaskState.working: + return a2a_pb2.TaskState.TASK_STATE_WORKING + case types.TaskState.completed: + return a2a_pb2.TaskState.TASK_STATE_COMPLETED + case types.TaskState.canceled: + return a2a_pb2.TaskState.TASK_STATE_CANCELLED + case types.TaskState.failed: + return a2a_pb2.TaskState.TASK_STATE_FAILED + case types.TaskState.input_required: + return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED + case _: + return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED + + @classmethod + def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: + return a2a_pb2.Artifact( + artifact_id=artifact.artifactId, + description=artifact.description, + metadata=ToProto.metadata(artifact.metadata), + name=artifact.name, + parts=[ToProto.part(p) for p in artifact.parts], + ) + + @classmethod + def authentication_info( + cls, info: types.PushNotificationAuthenticationInfo + ) -> a2a_pb2.AuthenticationInfo: + return a2a_pb2.AuthenticationInfo( + schemes=info.schemes, + credentials=info.credentials, + ) + + @classmethod + def push_notification_config( + cls, config: types.PushNotificationConfig + ) -> a2a_pb2.PushNotificationConfig: + return a2a_pb2.PushNotificationConfig( + id=config.id if id else "", + url=config.url, + token=config.token, + authentication=ToProto.authentication_info(config.authentication), + ) + + @classmethod + def task_artifact_update_event( + cls, event: types.TaskArtifactUpdateEvent + ) -> a2a_pb2.TaskArtifactUpdateEvent: + return a2a_pb2.TaskArtifactUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + artifact=ToProto.artifact(event.artifact), + metadata=ToProto.metadata(event.metadata), + append=event.append, + last_chunk=event.lastChunk, + ) + + @classmethod + def task_status_update_event( + cls, event: types.TaskStatusUpdateEvent + ) -> a2a_pb2.TaskStatusUpdateEvent: + return a2a_pb2.TaskStatusUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + status=ToProto.task_status(event.status), + metadata=ToProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def message_send_configuration( + cls, config: types.MessageSendConfiguration | None + ) -> a2a_pb2.SendMessageConfiguration: + if not config: + return a2a_pb2.SendMessageConfiguration() + return a2a_pb2.SendMessageConfiguration( + accepted_output_modes=list(config.acceptedOutputModes), + push_notification=ToProto.push_notification_config( + config.pushNotificationConfig + ), + history_length=config.historyLength, + blocking=config.blocking, + ) + + @classmethod + def update_event( + cls, event: types.Task | types.Message | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ) -> a2a_pb2.StreamResponse: + """Converts a task, message, or task update event to a StreamResponse.""" + if isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=ToProto.task_status_update_event(event) + ) + elif isinstance(event, types.TaskArtifactUpdateEvent): + return a2a_pb2.StreamResponse( + artifact_update=ToProto.task_artifact_update_event(event) + ) + elif isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=ToProto.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=ToProto.task(event)) + else: + raise ValueError(f'Unsupported event type: {type(event)}') + + @classmethod + def task_or_message( + cls, event: types.Task | types.Message + ) -> a2a_pb2.SendMessageResponse: + if isinstance(event, types.Message): + return a2a_pb2.SendMessageResponse( + msg=cls.message(event), + ) + return a2a_pb2.SendMessageResponse( + task=cls.task(event), + ) + + @classmethod + def stream_response( + cls, + event: ( + types.Message | + types.Task | + types.TaskStatusUpdateEvent | + types.TaskArtifactUpdateEvent) + ) -> a2a_pb2.StreamResponse: + if isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=cls.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=cls.task(event)) + elif isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=cls.task_status_update_event(event), + ) + return a2a_pb2.StreamResponse( + artifact_update=cls.task_artifact_update_event(event), + ) + + @classmethod + def task_push_notification_config( + cls, + config: types.TaskPushNotificationConfig + ) -> a2a_pb2.TaskPushNotificationConfig: + return a2a_pb2.TaskPushNotificationConfig( + name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}', + push_notification_config=cls.push_notification_config( + config.pushNotificationConfig, + ), + ) + + @classmethod + def agent_card( + cls, card: types.AgentCard, + ) -> a2a_pb2.AgentCard: + return a2a_pb2.AgentCard( + capabilities=cls.capabilities(card.capabilities), + default_input_modes=list(card.defaultInputModes), + default_output_modes=list(card.defaultOutputModes), + description=card.description, + documentation_url=card.documentationUrl, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + security_schemes=cls.security_schemes(card.securitySchemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard, + ) + + @classmethod + def capabilities( + cls, capabilities: types.AgentCapabilities + ) -> a2a_pb2.AgentCapabilities: + return a2a_pb2.AgentCapabilities( + streaming=capabilities.streaming, + push_notifications=capabilities.pushNotifications, + ) + + @classmethod + def provider( + cls, provider: types.AgentProvider | None + ) -> a2a_pb2.AgentProvider | None: + if not provider: + return None + return a2a_pb2.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, security: list[dict[str, list[str]]] | None, + ) -> list[a2a_pb2.Security] | None: + if not security: + return None + rval: list[a2a_pb2.Security] = [] + for s in security: + rval.append( + a2a_pb2.Security( + schemes={ + k: a2a_pb2.StringList(list=v.list) for (k, v) in s.items() + } + ) + ) + return rval + + @classmethod + def security_schemes( + cls, schemes: dict[str, types.SecurityScheme] | None, + ) -> dict[str, a2a_pb2.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, scheme: types.SecurityScheme, + ) -> a2a_pb2.SecurityScheme: + if isinstance(scheme.root, types.ApiKeySecurityScheme): + return a2a_pb2.SecurityScheme( + api_key_security_scheme=a2a_pb2.APIKeySecurityScheme( + description=scheme.root.description, + location=scheme.root.in_, + name=scheme.root.name, + ) + ) + if isinstance(scheme.root, types.HTTPAuthSecurityScheme): + return a2a_pb2.SecurityScheme( + http_auth_security_scheme=a2a_pb2.HttpAuthSecurityScheme( + description=scheme.root.description, + scheme=scheme.root.scheme, + bearer_format=scheme.root.bearerFormat, + ) + ) + if isinstance(scheme.root, types.Oauth2SecurityScheme): + return a2a_pb2.SecurityScheme( + oauth2_security_scheme=a2a_pb2.OAuth2SecurityScheme( + description=scheme.root.description, + flows=cls.oauth2_flows(scheme.root.flows), + ) + ) + return a2a_pb2.SecurityScheme( + open_id_connect_security_scheme=a2a_pb2.OpenIdConnectSecurityScheme( + description=scheme.root.description, + open_id_connect_url=scheme.root.openIdConnectUrl, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: + if flows.authorizationCode: + return a2a_pb2.OAuthFlows( + authorization_code=a2a_pb2.AuthorizationCodeAuthFlow( + authorization_url=flows.authorizationCode.authorizationUrl, + refresh_url=flows.authorizationCode.refreshUrl, + scopes={ + k: v for (k, v) in flows.authorizationCode.scopes.items() + }, + token_url=flows.authorizationCode.tokenUrl, + ), + ) + if flows.clientCredentials: + return a2a_pb2.OAuthFlows( + client_credentials=a2a_pb2.ClientCredentialsAuthFlow( + refresh_url=flows.clientCredentials.refreshUrl, + scopes={ + k:v for (k, v) in flows.clientCredentials.scopes.items() + }, + token_url=flows.client_credentials.tokenUrl, + ), + ) + if flows.implicit: + return a2a_pb2.OAuthFlows( + implicit=a2a_pb2.ImplicitOAuthFlow( + authorization_url=flows.implicit.authorization_Url, + refresh_url=flows.implicit.refreshUrl, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + return a2a_pb2.OAuthFlows( + password=types.PasswordOAuthFlow( + refresh_url=flows.password.refreshUrl, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + token_url=flows.password.tokenUrl, + ), + ) + + @classmethod + def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: + return a2a_pb2.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=skill.tags, + examples=skill.examples, + input_modes=skill.inputModes, + output_modes=skill.outputModes, + ) + + @classmethod + def role(cls, role: types.Role) -> a2a_pb2.Role: + match role: + case types.Role.user: + return a2a_pb2.Role.ROLE_USER + case types.Role.agent: + return a2a_pb2.Role.ROLE_AGENT + case _: + return a2a_pb2.Role.ROLE_UNSPECIFIED + + +class FromProto: + """Converts proto types to Python types.""" + + @classmethod + def message(cls, message: a2a_pb2.Message) -> types.Message: + return types.Message( + messageId=message.message_id, + parts=[FromProto.part(p) for p in message.content], + contextId=message.context_id, + taskId=message.task_id, + role=FromProto.role(message.role), + metadata=FromProto.metadata(message.metadata), + ) + + @classmethod + def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: + return { + key: value.string_value + for key, value in metadata.fields.items() + if value.string_value + } + + @classmethod + def part(cls, part: a2a_pb2.Part) -> types.Part: + if part.HasField('text'): + return types.Part(root=types.TextPart(text=part.text)) + elif part.HasField('file'): + return types.Part(root=types.FilePart(file=FromProto.file(part.file))) + elif part.HasField('data'): + return types.Part(root=types.DataPart(data=FromProto.data(part.data))) + else: + raise ValueError(f'Unsupported part type: {part}') + + @classmethod + def data(cls, data: a2a_pb2.DataPart) -> Dict[str, Any]: + json_data = json_format.MessageToJson(data.data) + return json.loads(json_data) + + @classmethod + def file( + cls, file: a2a_pb2.FilePart + ) -> types.FileWithUri | types.FileWithBytes: + if file.HasField('file_with_uri'): + return types.FileWithUri(uri=file.file_with_uri) + return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) + + @classmethod + def task(cls, task: a2a_pb2.Task) -> types.Task: + return types.Task( + id=task.id, + contextId=task.context_id, + status=FromProto.task_status(task.status), + artifacts=[FromProto.artifact(a) for a in task.artifacts], + history=[FromProto.message(h) for h in task.history], + ) + + @classmethod + def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: + return types.TaskStatus( + state=FromProto.task_state(status.state), + message=FromProto.message(status.update), + ) + + @classmethod + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: + match state: + case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: + return types.TaskState.submitted + case a2a_pb2.TaskState.TASK_STATE_WORKING: + return types.TaskState.working + case a2a_pb2.TaskState.TASK_STATE_COMPLETED: + return types.TaskState.completed + case a2a_pb2.TaskState.TASK_STATE_CANCELLED: + return types.TaskState.canceled + case a2a_pb2.TaskState.TASK_STATE_FAILED: + return types.TaskState.failed + case a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED: + return types.TaskState.input_required + case _: + return types.TaskState.unknown + + @classmethod + def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: + return types.Artifact( + artifactId=artifact.artifact_id, + description=artifact.description, + metadata=FromProto.metadata(artifact.metadata), + name=artifact.name, + parts=[FromProto.part(p) for p in artifact.parts], + ) + + @classmethod + def task_artifact_update_event( + cls, event: a2a_pb2.TaskArtifactUpdateEvent + ) -> types.TaskArtifactUpdateEvent: + return types.TaskArtifactUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + artifact=FromProto.artifact(event.artifact), + metadata=FromProto.metadata(event.metadata), + append=event.append, + lastChunk=event.last_chunk, + ) + + @classmethod + def task_status_update_event( + cls, event: a2a_pb2.TaskStatusUpdateEvent + ) -> types.TaskStatusUpdateEvent: + return types.TaskStatusUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + status=FromProto.task_status(event.status), + metadata=FromProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def push_notification_config( + cls, config: a2a_pb2.PushNotificationConfig + ) -> types.PushNotificationConfig: + return types.PushNotificationConfig( + id=config.id, + url=config.url, + token=config.token, + authentication=FromProto.authentication_info(config.authentication), + ) + + @classmethod + def authentication_info( + cls, info: a2a_pb2.AuthenticationInfo + ) -> types.PushNotificationAuthenticationInfo: + return types.PushNotificationAuthenticationInfo( + schemes=list(info.schemes), + credentials=info.credentials, + ) + + @classmethod + def message_send_configuration( + cls, config: a2a_pb2.SendMessageConfiguration + ) -> types.MessageSendConfiguration: + return types.MessageSendConfiguration( + acceptedOutputModes=list(config.accepted_output_modes), + pushNotificationConfig=FromProto.push_notification_config( + config.push_notification + ), + historyLength=config.history_length, + blocking=config.blocking, + ) + + @classmethod + def message_send_params( + cls, request: a2a_pb2.SendMessageRequest + ) -> types.MessageSendParams: + return types.MessageSendParams( + configuration=cls.message_send_configuration( + request.configuration + ), + message=cls.message(request.request), + metadata=cls.metadata(request.metadata), + ) + + @classmethod + def task_id_params( + cls, request: ( + a2a_pb2.CancelTaskRequest | + a2a_pb2.TaskSubscriptionRequest | + a2a_pb2.GetTaskPushNotificationRequest + ), + ) -> types.TaskIdParams: + # This is currently incomplete until the core sdk supports multiple + # configs for a single task. + if isinstance(request, a2a_pb2.GetTaskPushNotificationRequest): + m = re.match(_TASK_PUSH_NOTIFICATION_NAME_MATCH, task.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {task.name}' + ) + ) + return types.TaskIdParams(id = m.group(1)) + m = re.match(_TASK_NAME_MATCH, task.name) + if not m: + raise ServerError( + error=types.InvalidParamsError(message=f'No task for {task.name}') + ) + return types.TaskIdParams(id = m.group(1)) + + @classmethod + def task_push_notification_config( + cls, request: a2a_pb2.CreateTaskPushNotificationRequest, + ) -> types.TaskPushNotificationConfig: + m = re.match(_TASK_NAME_MATCH, request.parent) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.parent}' + ) + ) + return types.TaskPushNotificationConfig( + pushNotificationConfig=cls.push_notification_config( + request.config.push_notification_config, + ), + taskId=m.group(1), + ) + + @classmethod + def task_query_params( + cls, request: a2a_pb2.GetTaskRequest, + ) -> types.TaskQueryParams: + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.name}' + ) + ) + return types.TaskQueryParams( + historyLength=request.history_length if request.history_length else None, + id=m.group(1), + metadata=None, + ) + + @classmethod + def agent_card( + cls, card: a2a_pb2.AgentCard, + ) -> types.AgentCard: + return types.AgentCard( + capabilities=cls.capabilities(card.capabilities), + defaultInputModes=list(card.default_input_modes), + defaultOutputModes=list(card.default_output_modes), + description=card.description, + documentationUrl=card.documentation_url, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + securitySchemes=cls.security_schemes(card.security_schemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, + ) + + @classmethod + def capabilities( + cls, capabilities: a2a_pb2.AgentCapabilities + ) -> types.AgentCapabilities: + return types.AgentCapabilities( + streaming=capabilities.streaming, + pushNotifications=capabilities.push_notifications, + ) + + @classmethod + def provider( + cls, provider: a2a_pb2.AgentProvider | None + ) -> types.AgentProvider | None: + if not provider: + return None + return types.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, security: list[a2a_pb2.Security] | None, + ) -> list[dict[str, list[str]]] | None: + if not security: + return None + rval: list[dict[str, list[str]]] = [] + for s in security: + rval.append({k: list(v.list) for (k, v) in s.items()}) + return rval + + @classmethod + def security_schemes( + cls, schemes: dict[str, a2a_pb2.SecurityScheme] | None, + ) -> dict[str, types.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, scheme: a2a_pb2.SecurityScheme, + ) -> types.SecurityScheme: + if scheme.HasApiKeySecurityScheme(): + return types.SecurityScheme(root=types.APIKeySecurityScheme( + description=scheme.api_key_security_scheme.description, + in_=scheme.api_key_security_scheme.location, + name=scheme.api_key_security_scheme.name, + )) + if scheme.HasHttpAuthSecurityScheme(): + return types.SecurityScheme(root=types.HTTPAuthSecurityScheme( + description=scheme.http_auth_security_scheme.description, + scheme=scheme.http_auth_security_scheme.scheme, + bearerFormat=scheme.http_auth_security_scheme.bearer_format, + )) + if scheme.HasOauth2SecurityScheme(): + return types.SecurityScheme(root=types.OAuth2SecurityScheme( + description=scheme.oauth2_security_scheme.description, + flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), + )) + return types.SecurityScheme(root=types.OpenIdConnectSecurityScheme( + description=scheme.open_id_connect_security_scheme.description, + openIdConnectUrl=scheme.open_id_connect_security_scheme.open_id_connect_url, + )) + + @classmethod + def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: + if flows.HasAuthorizationCode(): + return types.OAuthFlows( + authorizationCode=types.AuthorizationCodeAuthFlow( + authorizationUrl=flows.authorization_code.authorization_url, + refreshUrl=flows.authorization_code.refresh_url, + scopes={ + k: v for (k, v) in flows.authorization_code.scopes.items() + }, + tokenUrl=flows.authorization_code.token_url, + ), + ) + if flows.HasClientCredentials(): + return types.OAuthFlows( + clientCredentials=types.ClientCredentialsAuthFlow( + refreshUrl=flows.client_credentials.refresh_url, + scopes={ + k:v for (k, v) in flows.client_credentials.scopes.items() + }, + tokenUrl=flows.client_credentials.token_url, + ), + ) + if flows.HasImplicit(): + return types.OAuthFlows( + implicit=types.ImplicitOAuthFlow( + authorizationUrl=flows.implicit.authorization_url, + refreshUrl=flows.implicit.refresh_url, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + return types.OAuthFlows( + password=types.PasswordOAuthFlow( + refreshUrl=flows.password.refresh_url, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + tokenUrl=flows.password.token_url, + ), + ) + + @classmethod + def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: + return types.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=list(skill.tags), + examples=list(skill.examples), + inputModes=list(skill.input_modes), + outputModes=list(skill.output_modes), + ) + + @classmethod + def role(cls, role: a2a_pb2.Role) -> types.Role: + match role: + case a2a_pb2.Role.ROLE_USER: + return types.Role.user + case a2a_pb2.Role.ROLE_AGENT: + return types.Role.agent + case _: + return types.Role.agent From f81a002e29900230cdb9801cecbd8f76a5b1b2d0 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Wed, 4 Jun 2025 20:53:42 +0000 Subject: [PATCH 06/29] Update pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 991fc8df4..058831464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "pydantic>=2.11.3", "sse-starlette>=2.3.3", "starlette>=0.46.2", + "grpcio>=1.60", + "grpcio-tools>=1.60", + "grpcio_reflection>=1.7.0", ] classifiers = [ From cd752b918a181bd0b914f69893f08a4d64e28429 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 00:57:36 +0000 Subject: [PATCH 07/29] feat: Update to support python 3.12 - add case statement to use the asyncio.Queue.shutdown method for 3.13+ - add special handling to allow for similar semantics as asyncio.Queue.shutdown for 3.12 Tested on multiple samples in the a2a repo and some examples in this repo --- pyproject.toml | 4 ++++ src/a2a/server/events/event_consumer.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 991fc8df4..9dcbd527a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,14 @@ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", +<<<<<<< HEAD "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", +======= + "Programming Language :: Python :: 3.12", +>>>>>>> 8ec734c (feat: Update to support python 3.12) "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index 518680695..2a96eca20 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -15,6 +15,12 @@ from a2a.utils.errors import ServerError from a2a.utils.telemetry import SpanKind, trace_class +# This is an alias to the execption for closed queue +QueueClosed = asyncio.QueueEmpty + +# When using python 3.13 or higher, the closed queue signal is QueueShutdown +if sys.version_info >= (3, 13): + QueueClosed = asyncio.QueueShutDown # This is an alias to the exception for closed queue QueueClosed = asyncio.QueueEmpty From 850ef3c7acfbeed6380242e753868add81cc48e2 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 15:08:38 +0000 Subject: [PATCH 08/29] Change to 3.10 and provided detailed description about event queue usage --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9dcbd527a..991fc8df4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,14 +22,10 @@ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", -<<<<<<< HEAD "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", -======= - "Programming Language :: Python :: 3.12", ->>>>>>> 8ec734c (feat: Update to support python 3.12) "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", From 85472d933c26edb6dee8398d32083be64cd4ae50 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 23:44:47 +0000 Subject: [PATCH 09/29] fix merge conflict --- src/a2a/server/events/event_consumer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index 2a96eca20..a5c31317d 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -15,14 +15,8 @@ from a2a.utils.errors import ServerError from a2a.utils.telemetry import SpanKind, trace_class -# This is an alias to the execption for closed queue -QueueClosed = asyncio.QueueEmpty -# When using python 3.13 or higher, the closed queue signal is QueueShutdown -if sys.version_info >= (3, 13): - QueueClosed = asyncio.QueueShutDown - -# This is an alias to the exception for closed queue +# This is an alias to the execption for closed queue QueueClosed = asyncio.QueueEmpty # When using python 3.13 or higher, the closed queue signal is QueueShutdown From 3526a2a5e95302d23571474aed78020657852d44 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 22 May 2025 23:45:50 +0000 Subject: [PATCH 10/29] Fix typo --- src/a2a/server/events/event_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index a5c31317d..518680695 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -16,7 +16,7 @@ from a2a.utils.telemetry import SpanKind, trace_class -# This is an alias to the execption for closed queue +# This is an alias to the exception for closed queue QueueClosed = asyncio.QueueEmpty # When using python 3.13 or higher, the closed queue signal is QueueShutdown From d7e3bce94509e469947bc759388ebc1bbc691b49 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Wed, 4 Jun 2025 20:33:31 +0000 Subject: [PATCH 11/29] Add gRPC based support in the SDK. - Introduces an A2AGrpcClient to talk to server over grpc - Introduces GrpcHandler to tranlate the gRPC transport to the internal python data model and back. - A set of transform operations in proto_utils.py to handle the transform This is a starting point and can be iterated and optimized as we move forward, especially trying to automate the transform code so it stays in sync. --- buf.gen.yaml | 5 +- src/a2a/client/__init__.py | 2 + src/a2a/client/grpc_client.py | 192 +++++ src/a2a/grpc/__init__.py | 0 src/a2a/grpc/a2a_pb2.py | 180 ++++ src/a2a/grpc/a2a_pb2.pyi | 520 ++++++++++++ src/a2a/grpc/a2a_pb2_grpc.py | 478 +++++++++++ src/a2a/server/request_handlers/__init__.py | 2 + .../default_request_handler.py | 3 + .../server/request_handlers/grpc_handler.py | 358 ++++++++ src/a2a/utils/helpers.py | 33 +- src/a2a/utils/proto_utils.py | 781 ++++++++++++++++++ 12 files changed, 2552 insertions(+), 2 deletions(-) create mode 100644 src/a2a/client/grpc_client.py create mode 100644 src/a2a/grpc/__init__.py create mode 100644 src/a2a/grpc/a2a_pb2.py create mode 100644 src/a2a/grpc/a2a_pb2.pyi create mode 100644 src/a2a/grpc/a2a_pb2_grpc.py create mode 100644 src/a2a/server/request_handlers/grpc_handler.py create mode 100644 src/a2a/utils/proto_utils.py diff --git a/buf.gen.yaml b/buf.gen.yaml index 7102471ef..e5e18e657 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -19,9 +19,12 @@ managed: plugins: # Generate python protobuf related code # Generates *_pb2.py files, one for each .proto - - remote: buf.build/protocolbuffers/python + - remote: buf.build/protocolbuffers/python:v29.3 out: src/a2a/grpc # Generate python service code. # Generates *_pb2_grpc.py - remote: buf.build/grpc/python out: src/a2a/grpc + # Generates *_pb2.pyi files. + - remote: buf.build/protocolbuffers/pyi:v29.3 + out: src/a2a/grpc diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 3455c8675..1a2bb5449 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -1,6 +1,7 @@ """Client-side components for interacting with an A2A agent.""" from a2a.client.client import A2ACardResolver, A2AClient +from a2a.client.grpc_client import A2AGrpcClient from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, @@ -15,5 +16,6 @@ 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', + 'A2AGrpcClient', 'create_text_message_object', ] diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py new file mode 100644 index 000000000..40fea5e42 --- /dev/null +++ b/src/a2a/client/grpc_client.py @@ -0,0 +1,192 @@ +import json +import logging +from collections.abc import AsyncGenerator +from typing import Any +from uuid import uuid4 +import grpc + +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.types import ( + AgentCard, + MessageSendParams, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + TaskPushNotificationConfig, + TaskIdParams, + TaskQueryParams, + Message, +) +from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils import proto_utils +from a2a.grpc import a2a_pb2_grpc +from a2a.grpc import a2a_pb2 + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class A2AGrpcClient: + """A2A Client for interacting with an A2A agent via gRPC.""" + + def __init__( + self, + grpc_stub: a2a_pb2_grpc.A2AServiceStub, + agent_card: AgentCard, + ): + """Initializes the A2AGrpcClient. + + Requires an `AgentCard` + + Args: + grpc_stub: A grpc client stub. + agent_card: The agent card object. + """ + self.agent_card = agent_card + self.stub = grpc_stub + + async def send_message( + self, + request: MessageSendParams, + ) -> Task | Message : + """Sends a non-streaming message request to the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Returns: + A `Task` or `Message` object containing the agent's response. + """ + response = await self.stub.SendMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + if response.task: + return proto_utils.FromProto.task(response.task) + return proto_utils.FromProto.message(response.msg) + + async def send_message_streaming( + self, + request: MessageSendParams, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses gRPC streams to receive a stream of updates from the + agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Yields: + `Message` or `Task` or `TaskStatusUpdateEvent` or + `TaskArtifactUpdateEvent` objects as they are received in the + stream. + """ + stream = self.stub.SendStreamingMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + if response.HasField('msg'): + yield proto_utils.FromProto.message(response.msg) + elif response.HasField('task'): + yield proto_utils.FromProto.task(response.task) + elif response.HasField('status_update'): + yield proto_utils.FromProto.task_status_update_event( + response.status_update + ) + elif response.HasField('artifact_update'): + yield proto_utils.FromProto.task_artifact_update_event( + response.artifact_update + ) + + async def get_task( + self, + request: TaskQueryParams, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID + + Returns: + A `Task` object containing the Task or None. + """ + task = await self.stub.GetTask( + a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `Task` object containing the updated Task + """ + task = await self.stub.CancelTask( + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + config = await self.stub.CreateTaskPushNotification( + a2a_pb2.CreateTaskPushNotificationRequest( + parent='', + config_id='', + config=proto_utils.ToProto.task_push_notification_config( + request + ), + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: TaskIdParams, # TODO: Update to a push id params + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + config = await self.stub.GetTaskPushNotification( + a2a_pb2.GetTaskPushNotificationRequest( + name=f'tasks/{request.id}/pushNotification/undefined', + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) diff --git a/src/a2a/grpc/__init__.py b/src/a2a/grpc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/a2a/grpc/a2a_pb2.py b/src/a2a/grpc/a2a_pb2.py new file mode 100644 index 000000000..81078b8be --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: a2a.proto +# Protobuf Python Version: 5.29.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 3, + '', + 'a2a.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 +from google.api import client_pb2 as google_dot_api_dot_client__pb2 +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x98\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12\'\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x06update\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"\xc8\x05\n\tAgentCard\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07request\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"4\n\x1eGetTaskPushNotificationRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa3\x01\n!CreateTaskPushNotificationRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"u\n\x1fListTaskPushNotificationRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"i\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msgB\t\n\x07payload\"\xf6\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msg\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x88\x01\n ListTaskPushNotificationResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xd0\x08\n\nA2AService\x12\x64\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1c\x82\xd3\xe4\x93\x02\x16\"\x11/v1//message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12W\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\" \x82\xd3\xe4\x93\x02\x1a\"\x15/v1/tasks/{id}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xb2\x01\n\x1a\x43reateTaskPushNotification\x12).a2a.v1.CreateTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"E\xda\x41\rparent,config\x82\xd3\xe4\x93\x02/\"%/v1/{parent=task/*/pushNotifications}:\x06\x63onfig\x12\x9c\x01\n\x17GetTaskPushNotification\x12&.a2a.v1.GetTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"5\xda\x41\x04name\x82\xd3\xe4\x93\x02(\x12&/v1/{name=tasks/*/pushNotifications/*}\x12\xa6\x01\n\x18ListTaskPushNotification\x12\'.a2a.v1.ListTaskPushNotificationRequest\x1a(.a2a.v1.ListTaskPushNotificationResponse\"7\xda\x41\x06parent\x82\xd3\xe4\x93\x02(\x12&/v1/{parent=tasks/*}/pushNotifications\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/cardBi\n\ncom.a2a.v1B\x08\x41\x32\x61ProtoP\x01Z\x18google.golang.org/a2a/v1\xa2\x02\x03\x41XX\xaa\x02\x06\x41\x32\x61.V1\xca\x02\x06\x41\x32\x61\\V1\xe2\x02\x12\x41\x32\x61\\V1\\GPBMetadata\xea\x02\x07\x41\x32\x61::V1b\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'a2a_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\ncom.a2a.v1B\010A2aProtoP\001Z\030google.golang.org/a2a/v1\242\002\003AXX\252\002\006A2a.V1\312\002\006A2a\\V1\342\002\022A2a\\V1\\GPBMetadata\352\002\007A2a::V1' + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._loaded_options = None + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_options = b'8\001' + _globals['_SECURITY_SCHEMESENTRY']._loaded_options = None + _globals['_SECURITY_SCHEMESENTRY']._serialized_options = b'8\001' + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._loaded_options = None + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._serialized_options = b'\340A\002' + _globals['_GETTASKREQUEST'].fields_by_name['name']._loaded_options = None + _globals['_GETTASKREQUEST'].fields_by_name['name']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._serialized_options = b'\340A\002' + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._serialized_options = b'\202\323\344\223\002\026\"\021/v1//message:send:\001*' + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._serialized_options = b'\202\323\344\223\002\027\"\022/v1/message:stream:\001*' + _globals['_A2ASERVICE'].methods_by_name['GetTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTask']._serialized_options = b'\332A\004name\202\323\344\223\002\024\022\022/v1/{name=tasks/*}' + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._serialized_options = b'\202\323\344\223\002\032\"\025/v1/tasks/{id}:cancel:\001*' + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._serialized_options = b'\202\323\344\223\002\036\022\034/v1/{name=tasks/*}:subscribe' + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._serialized_options = b'\332A\rparent,config\202\323\344\223\002/\"%/v1/{parent=task/*/pushNotifications}:\006config' + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._serialized_options = b'\332A\004name\202\323\344\223\002(\022&/v1/{name=tasks/*/pushNotifications/*}' + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._serialized_options = b'\332A\006parent\202\323\344\223\002(\022&/v1/{parent=tasks/*}/pushNotifications' + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._serialized_options = b'\202\323\344\223\002\n\022\010/v1/card' + _globals['_TASKSTATE']._serialized_start=7161 + _globals['_TASKSTATE']._serialized_end=7411 + _globals['_ROLE']._serialized_start=7413 + _globals['_ROLE']._serialized_end=7472 + _globals['_SENDMESSAGECONFIGURATION']._serialized_start=173 + _globals['_SENDMESSAGECONFIGURATION']._serialized_end=395 + _globals['_TASK']._serialized_start=398 + _globals['_TASK']._serialized_end=639 + _globals['_TASKSTATUS']._serialized_start=642 + _globals['_TASKSTATUS']._serialized_end=794 + _globals['_PART']._serialized_start=796 + _globals['_PART']._serialized_end=912 + _globals['_FILEPART']._serialized_start=914 + _globals['_FILEPART']._serialized_end=1041 + _globals['_DATAPART']._serialized_start=1043 + _globals['_DATAPART']._serialized_end=1098 + _globals['_MESSAGE']._serialized_start=1101 + _globals['_MESSAGE']._serialized_end=1356 + _globals['_ARTIFACT']._serialized_start=1359 + _globals['_ARTIFACT']._serialized_end=1577 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_start=1580 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_end=1778 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_start=1781 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_end=2016 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_start=2019 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_end=2167 + _globals['_AUTHENTICATIONINFO']._serialized_start=2169 + _globals['_AUTHENTICATIONINFO']._serialized_end=2249 + _globals['_AGENTCARD']._serialized_start=2252 + _globals['_AGENTCARD']._serialized_end=2964 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_start=2874 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_end=2964 + _globals['_AGENTPROVIDER']._serialized_start=2966 + _globals['_AGENTPROVIDER']._serialized_end=3035 + _globals['_AGENTCAPABILITIES']._serialized_start=3038 + _globals['_AGENTCAPABILITIES']._serialized_end=3190 + _globals['_AGENTEXTENSION']._serialized_start=3193 + _globals['_AGENTEXTENSION']._serialized_end=3338 + _globals['_AGENTSKILL']._serialized_start=3341 + _globals['_AGENTSKILL']._serialized_end=3539 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_start=3542 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_end=3680 + _globals['_STRINGLIST']._serialized_start=3682 + _globals['_STRINGLIST']._serialized_end=3714 + _globals['_SECURITY']._serialized_start=3717 + _globals['_SECURITY']._serialized_end=3864 + _globals['_SECURITY_SCHEMESENTRY']._serialized_start=3786 + _globals['_SECURITY_SCHEMESENTRY']._serialized_end=3864 + _globals['_SECURITYSCHEME']._serialized_start=3867 + _globals['_SECURITYSCHEME']._serialized_end=4268 + _globals['_APIKEYSECURITYSCHEME']._serialized_start=4270 + _globals['_APIKEYSECURITYSCHEME']._serialized_end=4374 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_start=4376 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_end=4495 + _globals['_OAUTH2SECURITYSCHEME']._serialized_start=4497 + _globals['_OAUTH2SECURITYSCHEME']._serialized_end=4595 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_start=4597 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_end=4707 + _globals['_OAUTHFLOWS']._serialized_start=4710 + _globals['_OAUTHFLOWS']._serialized_end=5014 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_start=5017 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_end=5283 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_start=5286 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_end=5507 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_IMPLICITOAUTHFLOW']._serialized_start=5510 + _globals['_IMPLICITOAUTHFLOW']._serialized_end=5729 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_PASSWORDOAUTHFLOW']._serialized_start=5732 + _globals['_PASSWORDOAUTHFLOW']._serialized_end=5935 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_SENDMESSAGEREQUEST']._serialized_start=5938 + _globals['_SENDMESSAGEREQUEST']._serialized_end=6131 + _globals['_GETTASKREQUEST']._serialized_start=6133 + _globals['_GETTASKREQUEST']._serialized_end=6213 + _globals['_CANCELTASKREQUEST']._serialized_start=6215 + _globals['_CANCELTASKREQUEST']._serialized_end=6254 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6256 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6308 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_start=6311 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_end=6474 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_start=6476 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_end=6521 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6523 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6640 + _globals['_GETAGENTCARDREQUEST']._serialized_start=6642 + _globals['_GETAGENTCARDREQUEST']._serialized_end=6663 + _globals['_SENDMESSAGERESPONSE']._serialized_start=6665 + _globals['_SENDMESSAGERESPONSE']._serialized_end=6770 + _globals['_STREAMRESPONSE']._serialized_start=6773 + _globals['_STREAMRESPONSE']._serialized_end=7019 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_start=7022 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_end=7158 + _globals['_A2ASERVICE']._serialized_start=7475 + _globals['_A2ASERVICE']._serialized_end=8579 +# @@protoc_insertion_point(module_scope) diff --git a/src/a2a/grpc/a2a_pb2.pyi b/src/a2a/grpc/a2a_pb2.pyi new file mode 100644 index 000000000..8d2fad9b8 --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.pyi @@ -0,0 +1,520 @@ +from google.api import annotations_pb2 as _annotations_pb2 +from google.api import client_pb2 as _client_pb2 +from google.api import field_behavior_pb2 as _field_behavior_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class TaskState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + TASK_STATE_UNSPECIFIED: _ClassVar[TaskState] + TASK_STATE_SUBMITTED: _ClassVar[TaskState] + TASK_STATE_WORKING: _ClassVar[TaskState] + TASK_STATE_COMPLETED: _ClassVar[TaskState] + TASK_STATE_FAILED: _ClassVar[TaskState] + TASK_STATE_CANCELLED: _ClassVar[TaskState] + TASK_STATE_INPUT_REQUIRED: _ClassVar[TaskState] + TASK_STATE_REJECTED: _ClassVar[TaskState] + TASK_STATE_AUTH_REQUIRED: _ClassVar[TaskState] + +class Role(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + ROLE_UNSPECIFIED: _ClassVar[Role] + ROLE_USER: _ClassVar[Role] + ROLE_AGENT: _ClassVar[Role] +TASK_STATE_UNSPECIFIED: TaskState +TASK_STATE_SUBMITTED: TaskState +TASK_STATE_WORKING: TaskState +TASK_STATE_COMPLETED: TaskState +TASK_STATE_FAILED: TaskState +TASK_STATE_CANCELLED: TaskState +TASK_STATE_INPUT_REQUIRED: TaskState +TASK_STATE_REJECTED: TaskState +TASK_STATE_AUTH_REQUIRED: TaskState +ROLE_UNSPECIFIED: Role +ROLE_USER: Role +ROLE_AGENT: Role + +class SendMessageConfiguration(_message.Message): + __slots__ = ("accepted_output_modes", "push_notification", "history_length", "blocking") + ACCEPTED_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + BLOCKING_FIELD_NUMBER: _ClassVar[int] + accepted_output_modes: _containers.RepeatedScalarFieldContainer[str] + push_notification: PushNotificationConfig + history_length: int + blocking: bool + def __init__(self, accepted_output_modes: _Optional[_Iterable[str]] = ..., push_notification: _Optional[_Union[PushNotificationConfig, _Mapping]] = ..., history_length: _Optional[int] = ..., blocking: bool = ...) -> None: ... + +class Task(_message.Message): + __slots__ = ("id", "context_id", "status", "artifacts", "history", "metadata") + ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + ARTIFACTS_FIELD_NUMBER: _ClassVar[int] + HISTORY_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + context_id: str + status: TaskStatus + artifacts: _containers.RepeatedCompositeFieldContainer[Artifact] + history: _containers.RepeatedCompositeFieldContainer[Message] + metadata: _struct_pb2.Struct + def __init__(self, id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., artifacts: _Optional[_Iterable[_Union[Artifact, _Mapping]]] = ..., history: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskStatus(_message.Message): + __slots__ = ("state", "update", "timestamp") + STATE_FIELD_NUMBER: _ClassVar[int] + UPDATE_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + state: TaskState + update: Message + timestamp: _timestamp_pb2.Timestamp + def __init__(self, state: _Optional[_Union[TaskState, str]] = ..., update: _Optional[_Union[Message, _Mapping]] = ..., timestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class Part(_message.Message): + __slots__ = ("text", "file", "data") + TEXT_FIELD_NUMBER: _ClassVar[int] + FILE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + text: str + file: FilePart + data: DataPart + def __init__(self, text: _Optional[str] = ..., file: _Optional[_Union[FilePart, _Mapping]] = ..., data: _Optional[_Union[DataPart, _Mapping]] = ...) -> None: ... + +class FilePart(_message.Message): + __slots__ = ("file_with_uri", "file_with_bytes", "mime_type") + FILE_WITH_URI_FIELD_NUMBER: _ClassVar[int] + FILE_WITH_BYTES_FIELD_NUMBER: _ClassVar[int] + MIME_TYPE_FIELD_NUMBER: _ClassVar[int] + file_with_uri: str + file_with_bytes: bytes + mime_type: str + def __init__(self, file_with_uri: _Optional[str] = ..., file_with_bytes: _Optional[bytes] = ..., mime_type: _Optional[str] = ...) -> None: ... + +class DataPart(_message.Message): + __slots__ = ("data",) + DATA_FIELD_NUMBER: _ClassVar[int] + data: _struct_pb2.Struct + def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class Message(_message.Message): + __slots__ = ("message_id", "context_id", "task_id", "role", "content", "metadata", "extensions") + MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + ROLE_FIELD_NUMBER: _ClassVar[int] + CONTENT_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + message_id: str + context_id: str + task_id: str + role: Role + content: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, message_id: _Optional[str] = ..., context_id: _Optional[str] = ..., task_id: _Optional[str] = ..., role: _Optional[_Union[Role, str]] = ..., content: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class Artifact(_message.Message): + __slots__ = ("artifact_id", "name", "description", "parts", "metadata", "extensions") + ARTIFACT_ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + PARTS_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + artifact_id: str + name: str + description: str + parts: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, artifact_id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., parts: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskStatusUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "status", "final", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + FINAL_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + status: TaskStatus + final: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., final: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskArtifactUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "artifact", "append", "last_chunk", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_FIELD_NUMBER: _ClassVar[int] + APPEND_FIELD_NUMBER: _ClassVar[int] + LAST_CHUNK_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + artifact: Artifact + append: bool + last_chunk: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., artifact: _Optional[_Union[Artifact, _Mapping]] = ..., append: bool = ..., last_chunk: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class PushNotificationConfig(_message.Message): + __slots__ = ("id", "url", "token", "authentication") + ID_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_FIELD_NUMBER: _ClassVar[int] + AUTHENTICATION_FIELD_NUMBER: _ClassVar[int] + id: str + url: str + token: str + authentication: AuthenticationInfo + def __init__(self, id: _Optional[str] = ..., url: _Optional[str] = ..., token: _Optional[str] = ..., authentication: _Optional[_Union[AuthenticationInfo, _Mapping]] = ...) -> None: ... + +class AuthenticationInfo(_message.Message): + __slots__ = ("schemes", "credentials") + SCHEMES_FIELD_NUMBER: _ClassVar[int] + CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.RepeatedScalarFieldContainer[str] + credentials: str + def __init__(self, schemes: _Optional[_Iterable[str]] = ..., credentials: _Optional[str] = ...) -> None: ... + +class AgentCard(_message.Message): + __slots__ = ("name", "description", "url", "provider", "version", "documentation_url", "capabilities", "security_schemes", "security", "default_input_modes", "default_output_modes", "skills", "supports_authenticated_extended_card") + class SecuritySchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: SecurityScheme + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[SecurityScheme, _Mapping]] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + PROVIDER_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + DOCUMENTATION_URL_FIELD_NUMBER: _ClassVar[int] + CAPABILITIES_FIELD_NUMBER: _ClassVar[int] + SECURITY_SCHEMES_FIELD_NUMBER: _ClassVar[int] + SECURITY_FIELD_NUMBER: _ClassVar[int] + DEFAULT_INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + DEFAULT_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + SKILLS_FIELD_NUMBER: _ClassVar[int] + SUPPORTS_AUTHENTICATED_EXTENDED_CARD_FIELD_NUMBER: _ClassVar[int] + name: str + description: str + url: str + provider: AgentProvider + version: str + documentation_url: str + capabilities: AgentCapabilities + security_schemes: _containers.MessageMap[str, SecurityScheme] + security: _containers.RepeatedCompositeFieldContainer[Security] + default_input_modes: _containers.RepeatedScalarFieldContainer[str] + default_output_modes: _containers.RepeatedScalarFieldContainer[str] + skills: _containers.RepeatedCompositeFieldContainer[AgentSkill] + supports_authenticated_extended_card: bool + def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., url: _Optional[str] = ..., provider: _Optional[_Union[AgentProvider, _Mapping]] = ..., version: _Optional[str] = ..., documentation_url: _Optional[str] = ..., capabilities: _Optional[_Union[AgentCapabilities, _Mapping]] = ..., security_schemes: _Optional[_Mapping[str, SecurityScheme]] = ..., security: _Optional[_Iterable[_Union[Security, _Mapping]]] = ..., default_input_modes: _Optional[_Iterable[str]] = ..., default_output_modes: _Optional[_Iterable[str]] = ..., skills: _Optional[_Iterable[_Union[AgentSkill, _Mapping]]] = ..., supports_authenticated_extended_card: bool = ...) -> None: ... + +class AgentProvider(_message.Message): + __slots__ = ("url", "organization") + URL_FIELD_NUMBER: _ClassVar[int] + ORGANIZATION_FIELD_NUMBER: _ClassVar[int] + url: str + organization: str + def __init__(self, url: _Optional[str] = ..., organization: _Optional[str] = ...) -> None: ... + +class AgentCapabilities(_message.Message): + __slots__ = ("streaming", "push_notifications", "extensions") + STREAMING_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATIONS_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + streaming: bool + push_notifications: bool + extensions: _containers.RepeatedCompositeFieldContainer[AgentExtension] + def __init__(self, streaming: bool = ..., push_notifications: bool = ..., extensions: _Optional[_Iterable[_Union[AgentExtension, _Mapping]]] = ...) -> None: ... + +class AgentExtension(_message.Message): + __slots__ = ("uri", "description", "required", "params") + URI_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + REQUIRED_FIELD_NUMBER: _ClassVar[int] + PARAMS_FIELD_NUMBER: _ClassVar[int] + uri: str + description: str + required: bool + params: _struct_pb2.Struct + def __init__(self, uri: _Optional[str] = ..., description: _Optional[str] = ..., required: bool = ..., params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class AgentSkill(_message.Message): + __slots__ = ("id", "name", "description", "tags", "examples", "input_modes", "output_modes") + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + TAGS_FIELD_NUMBER: _ClassVar[int] + EXAMPLES_FIELD_NUMBER: _ClassVar[int] + INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + description: str + tags: _containers.RepeatedScalarFieldContainer[str] + examples: _containers.RepeatedScalarFieldContainer[str] + input_modes: _containers.RepeatedScalarFieldContainer[str] + output_modes: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., tags: _Optional[_Iterable[str]] = ..., examples: _Optional[_Iterable[str]] = ..., input_modes: _Optional[_Iterable[str]] = ..., output_modes: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskPushNotificationConfig(_message.Message): + __slots__ = ("name", "push_notification_config") + NAME_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_CONFIG_FIELD_NUMBER: _ClassVar[int] + name: str + push_notification_config: PushNotificationConfig + def __init__(self, name: _Optional[str] = ..., push_notification_config: _Optional[_Union[PushNotificationConfig, _Mapping]] = ...) -> None: ... + +class StringList(_message.Message): + __slots__ = ("list",) + LIST_FIELD_NUMBER: _ClassVar[int] + list: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, list: _Optional[_Iterable[str]] = ...) -> None: ... + +class Security(_message.Message): + __slots__ = ("schemes",) + class SchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: StringList + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[StringList, _Mapping]] = ...) -> None: ... + SCHEMES_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.MessageMap[str, StringList] + def __init__(self, schemes: _Optional[_Mapping[str, StringList]] = ...) -> None: ... + +class SecurityScheme(_message.Message): + __slots__ = ("api_key_security_scheme", "http_auth_security_scheme", "oauth2_security_scheme", "open_id_connect_security_scheme") + API_KEY_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + HTTP_AUTH_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OAUTH2_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + api_key_security_scheme: APIKeySecurityScheme + http_auth_security_scheme: HTTPAuthSecurityScheme + oauth2_security_scheme: OAuth2SecurityScheme + open_id_connect_security_scheme: OpenIdConnectSecurityScheme + def __init__(self, api_key_security_scheme: _Optional[_Union[APIKeySecurityScheme, _Mapping]] = ..., http_auth_security_scheme: _Optional[_Union[HTTPAuthSecurityScheme, _Mapping]] = ..., oauth2_security_scheme: _Optional[_Union[OAuth2SecurityScheme, _Mapping]] = ..., open_id_connect_security_scheme: _Optional[_Union[OpenIdConnectSecurityScheme, _Mapping]] = ...) -> None: ... + +class APIKeySecurityScheme(_message.Message): + __slots__ = ("description", "location", "name") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + LOCATION_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + description: str + location: str + name: str + def __init__(self, description: _Optional[str] = ..., location: _Optional[str] = ..., name: _Optional[str] = ...) -> None: ... + +class HTTPAuthSecurityScheme(_message.Message): + __slots__ = ("description", "scheme", "bearer_format") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + SCHEME_FIELD_NUMBER: _ClassVar[int] + BEARER_FORMAT_FIELD_NUMBER: _ClassVar[int] + description: str + scheme: str + bearer_format: str + def __init__(self, description: _Optional[str] = ..., scheme: _Optional[str] = ..., bearer_format: _Optional[str] = ...) -> None: ... + +class OAuth2SecurityScheme(_message.Message): + __slots__ = ("description", "flows") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + FLOWS_FIELD_NUMBER: _ClassVar[int] + description: str + flows: OAuthFlows + def __init__(self, description: _Optional[str] = ..., flows: _Optional[_Union[OAuthFlows, _Mapping]] = ...) -> None: ... + +class OpenIdConnectSecurityScheme(_message.Message): + __slots__ = ("description", "open_id_connect_url") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_URL_FIELD_NUMBER: _ClassVar[int] + description: str + open_id_connect_url: str + def __init__(self, description: _Optional[str] = ..., open_id_connect_url: _Optional[str] = ...) -> None: ... + +class OAuthFlows(_message.Message): + __slots__ = ("authorization_code", "client_credentials", "implicit", "password") + AUTHORIZATION_CODE_FIELD_NUMBER: _ClassVar[int] + CLIENT_CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + IMPLICIT_FIELD_NUMBER: _ClassVar[int] + PASSWORD_FIELD_NUMBER: _ClassVar[int] + authorization_code: AuthorizationCodeOAuthFlow + client_credentials: ClientCredentialsOAuthFlow + implicit: ImplicitOAuthFlow + password: PasswordOAuthFlow + def __init__(self, authorization_code: _Optional[_Union[AuthorizationCodeOAuthFlow, _Mapping]] = ..., client_credentials: _Optional[_Union[ClientCredentialsOAuthFlow, _Mapping]] = ..., implicit: _Optional[_Union[ImplicitOAuthFlow, _Mapping]] = ..., password: _Optional[_Union[PasswordOAuthFlow, _Mapping]] = ...) -> None: ... + +class AuthorizationCodeOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ClientCredentialsOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ImplicitOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class PasswordOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class SendMessageRequest(_message.Message): + __slots__ = ("request", "configuration", "metadata") + REQUEST_FIELD_NUMBER: _ClassVar[int] + CONFIGURATION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + request: Message + configuration: SendMessageConfiguration + metadata: _struct_pb2.Struct + def __init__(self, request: _Optional[_Union[Message, _Mapping]] = ..., configuration: _Optional[_Union[SendMessageConfiguration, _Mapping]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class GetTaskRequest(_message.Message): + __slots__ = ("name", "history_length") + NAME_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + name: str + history_length: int + def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ...) -> None: ... + +class CancelTaskRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTaskPushNotificationRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class CreateTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "config_id", "config") + PARENT_FIELD_NUMBER: _ClassVar[int] + CONFIG_ID_FIELD_NUMBER: _ClassVar[int] + CONFIG_FIELD_NUMBER: _ClassVar[int] + parent: str + config_id: str + config: TaskPushNotificationConfig + def __init__(self, parent: _Optional[str] = ..., config_id: _Optional[str] = ..., config: _Optional[_Union[TaskPushNotificationConfig, _Mapping]] = ...) -> None: ... + +class TaskSubscriptionRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class ListTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "page_size", "page_token") + PARENT_FIELD_NUMBER: _ClassVar[int] + PAGE_SIZE_FIELD_NUMBER: _ClassVar[int] + PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + parent: str + page_size: int + page_token: str + def __init__(self, parent: _Optional[str] = ..., page_size: _Optional[int] = ..., page_token: _Optional[str] = ...) -> None: ... + +class GetAgentCardRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SendMessageResponse(_message.Message): + __slots__ = ("task", "msg") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ...) -> None: ... + +class StreamResponse(_message.Message): + __slots__ = ("task", "msg", "status_update", "artifact_update") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + STATUS_UPDATE_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_UPDATE_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + status_update: TaskStatusUpdateEvent + artifact_update: TaskArtifactUpdateEvent + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ..., status_update: _Optional[_Union[TaskStatusUpdateEvent, _Mapping]] = ..., artifact_update: _Optional[_Union[TaskArtifactUpdateEvent, _Mapping]] = ...) -> None: ... + +class ListTaskPushNotificationResponse(_message.Message): + __slots__ = ("configs", "next_page_token") + CONFIGS_FIELD_NUMBER: _ClassVar[int] + NEXT_PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + configs: _containers.RepeatedCompositeFieldContainer[TaskPushNotificationConfig] + next_page_token: str + def __init__(self, configs: _Optional[_Iterable[_Union[TaskPushNotificationConfig, _Mapping]]] = ..., next_page_token: _Optional[str] = ...) -> None: ... diff --git a/src/a2a/grpc/a2a_pb2_grpc.py b/src/a2a/grpc/a2a_pb2_grpc.py new file mode 100644 index 000000000..01a283739 --- /dev/null +++ b/src/a2a/grpc/a2a_pb2_grpc.py @@ -0,0 +1,478 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import a2a_pb2 as a2a__pb2 + + +class A2AServiceStub(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendMessage = channel.unary_unary( + '/a2a.v1.A2AService/SendMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.SendMessageResponse.FromString, + _registered_method=True) + self.SendStreamingMessage = channel.unary_stream( + '/a2a.v1.A2AService/SendStreamingMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.GetTask = channel.unary_unary( + '/a2a.v1.A2AService/GetTask', + request_serializer=a2a__pb2.GetTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.CancelTask = channel.unary_unary( + '/a2a.v1.A2AService/CancelTask', + request_serializer=a2a__pb2.CancelTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.TaskSubscription = channel.unary_stream( + '/a2a.v1.A2AService/TaskSubscription', + request_serializer=a2a__pb2.TaskSubscriptionRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.CreateTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/CreateTaskPushNotification', + request_serializer=a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.GetTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/GetTaskPushNotification', + request_serializer=a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.ListTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/ListTaskPushNotification', + request_serializer=a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.ListTaskPushNotificationResponse.FromString, + _registered_method=True) + self.GetAgentCard = channel.unary_unary( + '/a2a.v1.A2AService/GetAgentCard', + request_serializer=a2a__pb2.GetAgentCardRequest.SerializeToString, + response_deserializer=a2a__pb2.AgentCard.FromString, + _registered_method=True) + + +class A2AServiceServicer(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def SendMessage(self, request, context): + """Send a message to the agent. This is a blocking call that will return the + task once it is completed, or a LRO if requested. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendStreamingMessage(self, request, context): + """SendStreamingMessage is a streaming call that will return a stream of + task update events until the Task is in an interrupted or terminal state. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTask(self, request, context): + """Get the current state of a task from the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CancelTask(self, request, context): + """Cancel a task from the agent. If supported one should expect no + more task updates for the task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TaskSubscription(self, request, context): + """TaskSubscription is a streaming call that will return a stream of task + update events. This attaches the stream to an existing in process task. + If the task is complete the stream will return the completed task (like + GetTask) and close the stream. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateTaskPushNotification(self, request, context): + """Set a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTaskPushNotification(self, request, context): + """Get a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListTaskPushNotification(self, request, context): + """Get a list of push notifications configured for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetAgentCard(self, request, context): + """GetAgentCard returns the agent card for the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_A2AServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendMessage': grpc.unary_unary_rpc_method_handler( + servicer.SendMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.SendMessageResponse.SerializeToString, + ), + 'SendStreamingMessage': grpc.unary_stream_rpc_method_handler( + servicer.SendStreamingMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'GetTask': grpc.unary_unary_rpc_method_handler( + servicer.GetTask, + request_deserializer=a2a__pb2.GetTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'CancelTask': grpc.unary_unary_rpc_method_handler( + servicer.CancelTask, + request_deserializer=a2a__pb2.CancelTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'TaskSubscription': grpc.unary_stream_rpc_method_handler( + servicer.TaskSubscription, + request_deserializer=a2a__pb2.TaskSubscriptionRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'CreateTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.CreateTaskPushNotification, + request_deserializer=a2a__pb2.CreateTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'GetTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.GetTaskPushNotification, + request_deserializer=a2a__pb2.GetTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'ListTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.ListTaskPushNotification, + request_deserializer=a2a__pb2.ListTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.ListTaskPushNotificationResponse.SerializeToString, + ), + 'GetAgentCard': grpc.unary_unary_rpc_method_handler( + servicer.GetAgentCard, + request_deserializer=a2a__pb2.GetAgentCardRequest.FromString, + response_serializer=a2a__pb2.AgentCard.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'a2a.v1.A2AService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('a2a.v1.A2AService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class A2AService(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + @staticmethod + def SendMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/SendMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.SendMessageResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendStreamingMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/SendStreamingMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTask', + a2a__pb2.GetTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CancelTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CancelTask', + a2a__pb2.CancelTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def TaskSubscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/TaskSubscription', + a2a__pb2.TaskSubscriptionRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CreateTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CreateTaskPushNotification', + a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTaskPushNotification', + a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ListTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/ListTaskPushNotification', + a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + a2a__pb2.ListTaskPushNotificationResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetAgentCard(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetAgentCard', + a2a__pb2.GetAgentCardRequest.SerializeToString, + a2a__pb2.AgentCard.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index f0d2667d8..623843848 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -4,6 +4,7 @@ DefaultRequestHandler, ) from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler +from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, @@ -14,6 +15,7 @@ __all__ = [ 'DefaultRequestHandler', 'JSONRPCHandler', + 'GrpcHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 09b1d3049..660ef7ef2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -3,6 +3,7 @@ from collections.abc import AsyncGenerator from typing import cast +import uuid from a2a.server.agent_execution import ( AgentExecutor, @@ -364,6 +365,8 @@ async def on_set_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) + # Generate a unique id for the notification + params.pushNotificationConfig.id = str(uuid.uuid4()) await self._push_notifier.set_info( params.taskId, params.pushNotificationConfig, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py new file mode 100644 index 000000000..66c24804b --- /dev/null +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -0,0 +1,358 @@ +import logging +import grpc +import contextlib + +from typing import AsyncIterable +from abc import ABC, abstractmethod + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + InternalError, + Message, + Task, + TaskArtifactUpdateEvent, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, +) +from a2a import types +from a2a.auth.user import User as A2AUser +from a2a.auth.user import UnauthenticatedUser +from a2a.server.context import ServerCallContext +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate, validate_async_generator +from a2a.utils import proto_utils +import a2a.grpc.a2a_pb2 as a2a_pb2 +import a2a.grpc.a2a_pb2_grpc as a2a_grpc + + +logger = logging.getLogger(__name__) + +# For now we use a trivial wrapper on the grpc context object + +class CallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext from a gRPC Request.""" + + +class DefaultCallContextBuilder(CallContextBuilder): + """A default implementation of CallContextBuilder.""" + + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + user = UnauthenticatedUser() + state = {} + with contextlib.suppress(Exception): + state['grpc_context'] = context + return ServerCallContext(user=user, state=state) + + +class GrpcHandler(a2a_grpc.A2AServiceServicer): + """Maps incoming gRPC requests to the appropriate request handler method + and formats responses.""" + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + context_builder: CallContextBuilder = DefaultCallContextBuilder(), + ): + """Initializes the GrpcHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegat +e requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + self.context_builder = context_builder + + async def SendMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.SendMessageResponse: + """Handles the 'SendMessage' gRPC method. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Returns: + A `SendMessageResponse` object containing the result (Task or Messag +e) + or throws an error response if a `ServerError` is raised by the han +dler. + """ + try: + # Construct the server context object + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, server_context + ) + return proto_utils.ToProto.task_or_message(task_or_message) + except ServerError as e: + await self.abort_context(e, context) + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def SendStreamingMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'StreamMessage' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + or gRPC error responses if a `ServerError` is raised. + """ + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + try: + async for event in self.request_handler.on_message_send_stream( + a2a_request, server_context + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + return + + async def CancelTask( + self, + request: a2a_pb2.CancelTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'CancelTask' gRPC method. + + Args: + request: The incoming `CancelTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the updated Task or a gRPC error. + """ + try: + server_context = self.context_builder.build(context) + task_id_params = proto_utils.FromProto.task_id_params(request) + task = await self.request_handler.on_cancel_task( + task_id_params, server_context + ) + if task: + return proto_utils.ToProto.task(task) + self.abort_context(ServerError(error=TaskNotFoundError()), context) + except ServerError as e: + await self.abort_context(e, context) + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def TaskSubscription( + self, + request: a2a_pb2.TaskSubscriptionRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'TaskSubscription' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `TaskSubscriptionRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + """ + try: + server_context = self.context_builder.build(context) + async for event in self.request_handler.on_resubscribe_to_task( + proto_utils.FromProto.task_id_params(request), server_context, + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + + async def GetTaskPushNotification( + self, + request: a2a_pb2.GetTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Handles the 'GetTaskPushNotification' gRPC method. + + Args: + request: The incoming `GetTaskPushNotificationConfigRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_get_task_push_notification_config( + proto_utils.FromProto.task_id_params(request), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def CreateTaskPushNotification( + self, + request: a2a_pb2.CreateTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> TaskPushNotificationConfig: + """Handles the 'CreateTaskPushNotification' gRPC method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `CreateTaskPushNotificationRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator). + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_set_task_push_notification_config( + proto_utils.FromProto.task_push_notification_config( + request, + ), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + + async def GetTask( + self, + request: a2a_pb2.GetTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'GetTask' gRPC method. + + Args: + request: The incoming `GetTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object. + """ + try: + server_context = self.context_builder.build(context) + task = await self.request_handler.on_get_task( + proto_utils.FromProto.task_query_params(request), server_context + ) + if task: + return proto_utils.ToProto.task(task) + self.abort_context(ServerError(error=TaskNotFoundError()), context) + except ServerError as e: + await self.abort_context(e, context) + + async def GetAgentCard( + self, + request: a2a_pb2.GetAgentCardRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.AgentCard: + return proto_utils.ToProto.agent_card(self.agent_card) + + async def abort_context( + self, error: ServerError, context: grpc.ServicerContext + ): + match error.error: + case types.JSONParseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'JSONParseError: {error.error.message}', + ) + case types.InvalidRequestError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidRequestError: {error.error.message}', + ) + case types.MethodNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'MethodNotFoundError: {error.error.message}', + ) + case types.InvalidParamsError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidParamsError: {error.error.message}', + ) + case types.InternalError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InternalError: {error.error.message}', + ) + case types.TaskNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'TaskNotFoundError: {error.error.message}', + ) + case types.TaskNotCancelableError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'TaskNotCancelableError: {error.error.message}', + ) + case types.PushNotificationNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'PushNotificationNotSupportedError: {error.error.message}', + ) + case types.UnsupportedOperationError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'UnsupportedOperationError: {error.error.message}', + ) + case types.ContentTypeNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'ContentTypeNotSupportedError: {error.error.message}', + ) + case types.InvalidAgentResponseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InvalidAgentResponseError: {error.error.message}', + ) + case _: + await context.abort( + grpc.StatusCode.UNKNOWN, + f'Unknown error type: {error.error}', + ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 243ac87b0..4260aa6e1 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,5 +1,5 @@ """General utility functions for the A2A Python SDK.""" - +import functools import logging from collections.abc import Callable @@ -147,6 +147,37 @@ def wrapper(self, *args, **kwargs): return decorator +def validate_async_generator( + expression: Callable[[Any], bool], error_message: str | None = None +): + """Decorator that validates if a given expression evaluates to True. + + Typically used on class methods to check capabilities or configuration + before executing the method's logic. If the expression is False, + a `ServerError` with an `UnsupportedOperationError` is raised. + + Args: + expression: A callable that takes the instance (`self`) as its argument + and returns a boolean. + error_message: An optional custom error message for the `UnsupportedOperationError`. + If None, the string representation of the expression will be used. + """ + + def decorator(function): + @functools.wraps(function) + async def wrapper(self, *args, **kwargs): + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + async for i in function(self, *args, **kwargs): + yield i + + return wrapper + + return decorator def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py new file mode 100644 index 000000000..bc78abbf8 --- /dev/null +++ b/src/a2a/utils/proto_utils.py @@ -0,0 +1,781 @@ +"""Utils for converting between proto and Python types.""" + +import json +from typing import Any, Dict +import re + +from a2a.grpc import a2a_pb2 +from a2a import types +from a2a.utils.errors import ServerError +from google.protobuf import struct_pb2 +from google.protobuf import json_format + + +# Regexp patterns for matching +_TASK_NAME_MATCH = r'tasks/(\w+)' +_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)' + + +class ToProto: + """Converts Python types to proto types.""" + + @classmethod + def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: + if message is None: + return None + return a2a_pb2.Message( + message_id=message.messageId, + content=[ToProto.part(p) for p in message.parts], + context_id=message.contextId, + task_id=message.taskId, + role=cls.role(message.role.name), + metadata=ToProto.metadata(message.metadata), + ) + + @classmethod + def metadata( + cls, metadata: Dict[str, Any] | None + ) -> struct_pb2.Struct | None: + if metadata is None: + return None + return struct_pb2.Struct( + # TODO: Add support for other types. + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in metadata.items() + if isinstance(value, str) + } + ) + + @classmethod + def part(cls, part: types.Part) -> a2a_pb2.Part: + if isinstance(part.root, types.TextPart): + return a2a_pb2.Part(text=part.root.text) + elif isinstance(part.root, types.FilePart): + return a2a_pb2.Part(file=ToProto.file(part.root.file)) + elif isinstance(part.root, types.DataPart): + return a2a_pb2.Part(data=ToProto.data(part.root.data)) + else: + raise ValueError(f'Unsupported part type: {part.root}') + + @classmethod + def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: + json_data = json.dumps(data) + return a2a_pb2.DataPart( + data=json_format.Parse( + json_data, + struct_pb2.Struct(), + ) + ) + + @classmethod + def file( + cls, file: types.FileWithUri | types.FileWithBytes + ) -> a2a_pb2.FilePart: + if isinstance(file, types.FileWithUri): + return a2a_pb2.FilePart(file_with_uri=file.uri) + return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) + + @classmethod + def task(cls, task: types.Task) -> a2a_pb2.Task: + return a2a_pb2.Task( + id=task.id, + context_id=task.contextId, + status=ToProto.task_status(task.status), + artifacts=([ + ToProto.artifact(a) for a in task.artifacts + ] if task.artifacts else None), + history=([ + ToProto.message(h) for h in task.history + ] if task.history else None), + ) + + @classmethod + def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: + return a2a_pb2.TaskStatus( + state=ToProto.task_state(status.state), + update=ToProto.message(status.message), + ) + + @classmethod + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: + match state: + case types.TaskState.submitted: + return a2a_pb2.TaskState.TASK_STATE_SUBMITTED + case types.TaskState.working: + return a2a_pb2.TaskState.TASK_STATE_WORKING + case types.TaskState.completed: + return a2a_pb2.TaskState.TASK_STATE_COMPLETED + case types.TaskState.canceled: + return a2a_pb2.TaskState.TASK_STATE_CANCELLED + case types.TaskState.failed: + return a2a_pb2.TaskState.TASK_STATE_FAILED + case types.TaskState.input_required: + return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED + case _: + return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED + + @classmethod + def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: + return a2a_pb2.Artifact( + artifact_id=artifact.artifactId, + description=artifact.description, + metadata=ToProto.metadata(artifact.metadata), + name=artifact.name, + parts=[ToProto.part(p) for p in artifact.parts], + ) + + @classmethod + def authentication_info( + cls, info: types.PushNotificationAuthenticationInfo + ) -> a2a_pb2.AuthenticationInfo: + return a2a_pb2.AuthenticationInfo( + schemes=info.schemes, + credentials=info.credentials, + ) + + @classmethod + def push_notification_config( + cls, config: types.PushNotificationConfig + ) -> a2a_pb2.PushNotificationConfig: + return a2a_pb2.PushNotificationConfig( + id=config.id if id else "", + url=config.url, + token=config.token, + authentication=ToProto.authentication_info(config.authentication), + ) + + @classmethod + def task_artifact_update_event( + cls, event: types.TaskArtifactUpdateEvent + ) -> a2a_pb2.TaskArtifactUpdateEvent: + return a2a_pb2.TaskArtifactUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + artifact=ToProto.artifact(event.artifact), + metadata=ToProto.metadata(event.metadata), + append=event.append, + last_chunk=event.lastChunk, + ) + + @classmethod + def task_status_update_event( + cls, event: types.TaskStatusUpdateEvent + ) -> a2a_pb2.TaskStatusUpdateEvent: + return a2a_pb2.TaskStatusUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + status=ToProto.task_status(event.status), + metadata=ToProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def message_send_configuration( + cls, config: types.MessageSendConfiguration | None + ) -> a2a_pb2.SendMessageConfiguration: + if not config: + return a2a_pb2.SendMessageConfiguration() + return a2a_pb2.SendMessageConfiguration( + accepted_output_modes=list(config.acceptedOutputModes), + push_notification=ToProto.push_notification_config( + config.pushNotificationConfig + ), + history_length=config.historyLength, + blocking=config.blocking, + ) + + @classmethod + def update_event( + cls, event: types.Task | types.Message | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ) -> a2a_pb2.StreamResponse: + """Converts a task, message, or task update event to a StreamResponse.""" + if isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=ToProto.task_status_update_event(event) + ) + elif isinstance(event, types.TaskArtifactUpdateEvent): + return a2a_pb2.StreamResponse( + artifact_update=ToProto.task_artifact_update_event(event) + ) + elif isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=ToProto.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=ToProto.task(event)) + else: + raise ValueError(f'Unsupported event type: {type(event)}') + + @classmethod + def task_or_message( + cls, event: types.Task | types.Message + ) -> a2a_pb2.SendMessageResponse: + if isinstance(event, types.Message): + return a2a_pb2.SendMessageResponse( + msg=cls.message(event), + ) + return a2a_pb2.SendMessageResponse( + task=cls.task(event), + ) + + @classmethod + def stream_response( + cls, + event: ( + types.Message | + types.Task | + types.TaskStatusUpdateEvent | + types.TaskArtifactUpdateEvent) + ) -> a2a_pb2.StreamResponse: + if isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=cls.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=cls.task(event)) + elif isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=cls.task_status_update_event(event), + ) + return a2a_pb2.StreamResponse( + artifact_update=cls.task_artifact_update_event(event), + ) + + @classmethod + def task_push_notification_config( + cls, + config: types.TaskPushNotificationConfig + ) -> a2a_pb2.TaskPushNotificationConfig: + return a2a_pb2.TaskPushNotificationConfig( + name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}', + push_notification_config=cls.push_notification_config( + config.pushNotificationConfig, + ), + ) + + @classmethod + def agent_card( + cls, card: types.AgentCard, + ) -> a2a_pb2.AgentCard: + return a2a_pb2.AgentCard( + capabilities=cls.capabilities(card.capabilities), + default_input_modes=list(card.defaultInputModes), + default_output_modes=list(card.defaultOutputModes), + description=card.description, + documentation_url=card.documentationUrl, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + security_schemes=cls.security_schemes(card.securitySchemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard, + ) + + @classmethod + def capabilities( + cls, capabilities: types.AgentCapabilities + ) -> a2a_pb2.AgentCapabilities: + return a2a_pb2.AgentCapabilities( + streaming=capabilities.streaming, + push_notifications=capabilities.pushNotifications, + ) + + @classmethod + def provider( + cls, provider: types.AgentProvider | None + ) -> a2a_pb2.AgentProvider | None: + if not provider: + return None + return a2a_pb2.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, security: list[dict[str, list[str]]] | None, + ) -> list[a2a_pb2.Security] | None: + if not security: + return None + rval: list[a2a_pb2.Security] = [] + for s in security: + rval.append( + a2a_pb2.Security( + schemes={ + k: a2a_pb2.StringList(list=v.list) for (k, v) in s.items() + } + ) + ) + return rval + + @classmethod + def security_schemes( + cls, schemes: dict[str, types.SecurityScheme] | None, + ) -> dict[str, a2a_pb2.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, scheme: types.SecurityScheme, + ) -> a2a_pb2.SecurityScheme: + if isinstance(scheme.root, types.ApiKeySecurityScheme): + return a2a_pb2.SecurityScheme( + api_key_security_scheme=a2a_pb2.APIKeySecurityScheme( + description=scheme.root.description, + location=scheme.root.in_, + name=scheme.root.name, + ) + ) + if isinstance(scheme.root, types.HTTPAuthSecurityScheme): + return a2a_pb2.SecurityScheme( + http_auth_security_scheme=a2a_pb2.HttpAuthSecurityScheme( + description=scheme.root.description, + scheme=scheme.root.scheme, + bearer_format=scheme.root.bearerFormat, + ) + ) + if isinstance(scheme.root, types.Oauth2SecurityScheme): + return a2a_pb2.SecurityScheme( + oauth2_security_scheme=a2a_pb2.OAuth2SecurityScheme( + description=scheme.root.description, + flows=cls.oauth2_flows(scheme.root.flows), + ) + ) + return a2a_pb2.SecurityScheme( + open_id_connect_security_scheme=a2a_pb2.OpenIdConnectSecurityScheme( + description=scheme.root.description, + open_id_connect_url=scheme.root.openIdConnectUrl, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: + if flows.authorizationCode: + return a2a_pb2.OAuthFlows( + authorization_code=a2a_pb2.AuthorizationCodeAuthFlow( + authorization_url=flows.authorizationCode.authorizationUrl, + refresh_url=flows.authorizationCode.refreshUrl, + scopes={ + k: v for (k, v) in flows.authorizationCode.scopes.items() + }, + token_url=flows.authorizationCode.tokenUrl, + ), + ) + if flows.clientCredentials: + return a2a_pb2.OAuthFlows( + client_credentials=a2a_pb2.ClientCredentialsAuthFlow( + refresh_url=flows.clientCredentials.refreshUrl, + scopes={ + k:v for (k, v) in flows.clientCredentials.scopes.items() + }, + token_url=flows.client_credentials.tokenUrl, + ), + ) + if flows.implicit: + return a2a_pb2.OAuthFlows( + implicit=a2a_pb2.ImplicitOAuthFlow( + authorization_url=flows.implicit.authorization_Url, + refresh_url=flows.implicit.refreshUrl, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + return a2a_pb2.OAuthFlows( + password=types.PasswordOAuthFlow( + refresh_url=flows.password.refreshUrl, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + token_url=flows.password.tokenUrl, + ), + ) + + @classmethod + def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: + return a2a_pb2.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=skill.tags, + examples=skill.examples, + input_modes=skill.inputModes, + output_modes=skill.outputModes, + ) + + @classmethod + def role(cls, role: types.Role) -> a2a_pb2.Role: + match role: + case types.Role.user: + return a2a_pb2.Role.ROLE_USER + case types.Role.agent: + return a2a_pb2.Role.ROLE_AGENT + case _: + return a2a_pb2.Role.ROLE_UNSPECIFIED + + +class FromProto: + """Converts proto types to Python types.""" + + @classmethod + def message(cls, message: a2a_pb2.Message) -> types.Message: + return types.Message( + messageId=message.message_id, + parts=[FromProto.part(p) for p in message.content], + contextId=message.context_id, + taskId=message.task_id, + role=FromProto.role(message.role), + metadata=FromProto.metadata(message.metadata), + ) + + @classmethod + def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: + return { + key: value.string_value + for key, value in metadata.fields.items() + if value.string_value + } + + @classmethod + def part(cls, part: a2a_pb2.Part) -> types.Part: + if part.HasField('text'): + return types.Part(root=types.TextPart(text=part.text)) + elif part.HasField('file'): + return types.Part(root=types.FilePart(file=FromProto.file(part.file))) + elif part.HasField('data'): + return types.Part(root=types.DataPart(data=FromProto.data(part.data))) + else: + raise ValueError(f'Unsupported part type: {part}') + + @classmethod + def data(cls, data: a2a_pb2.DataPart) -> Dict[str, Any]: + json_data = json_format.MessageToJson(data.data) + return json.loads(json_data) + + @classmethod + def file( + cls, file: a2a_pb2.FilePart + ) -> types.FileWithUri | types.FileWithBytes: + if file.HasField('file_with_uri'): + return types.FileWithUri(uri=file.file_with_uri) + return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) + + @classmethod + def task(cls, task: a2a_pb2.Task) -> types.Task: + return types.Task( + id=task.id, + contextId=task.context_id, + status=FromProto.task_status(task.status), + artifacts=[FromProto.artifact(a) for a in task.artifacts], + history=[FromProto.message(h) for h in task.history], + ) + + @classmethod + def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: + return types.TaskStatus( + state=FromProto.task_state(status.state), + message=FromProto.message(status.update), + ) + + @classmethod + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: + match state: + case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: + return types.TaskState.submitted + case a2a_pb2.TaskState.TASK_STATE_WORKING: + return types.TaskState.working + case a2a_pb2.TaskState.TASK_STATE_COMPLETED: + return types.TaskState.completed + case a2a_pb2.TaskState.TASK_STATE_CANCELLED: + return types.TaskState.canceled + case a2a_pb2.TaskState.TASK_STATE_FAILED: + return types.TaskState.failed + case a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED: + return types.TaskState.input_required + case _: + return types.TaskState.unknown + + @classmethod + def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: + return types.Artifact( + artifactId=artifact.artifact_id, + description=artifact.description, + metadata=FromProto.metadata(artifact.metadata), + name=artifact.name, + parts=[FromProto.part(p) for p in artifact.parts], + ) + + @classmethod + def task_artifact_update_event( + cls, event: a2a_pb2.TaskArtifactUpdateEvent + ) -> types.TaskArtifactUpdateEvent: + return types.TaskArtifactUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + artifact=FromProto.artifact(event.artifact), + metadata=FromProto.metadata(event.metadata), + append=event.append, + lastChunk=event.last_chunk, + ) + + @classmethod + def task_status_update_event( + cls, event: a2a_pb2.TaskStatusUpdateEvent + ) -> types.TaskStatusUpdateEvent: + return types.TaskStatusUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + status=FromProto.task_status(event.status), + metadata=FromProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def push_notification_config( + cls, config: a2a_pb2.PushNotificationConfig + ) -> types.PushNotificationConfig: + return types.PushNotificationConfig( + id=config.id, + url=config.url, + token=config.token, + authentication=FromProto.authentication_info(config.authentication), + ) + + @classmethod + def authentication_info( + cls, info: a2a_pb2.AuthenticationInfo + ) -> types.PushNotificationAuthenticationInfo: + return types.PushNotificationAuthenticationInfo( + schemes=list(info.schemes), + credentials=info.credentials, + ) + + @classmethod + def message_send_configuration( + cls, config: a2a_pb2.SendMessageConfiguration + ) -> types.MessageSendConfiguration: + return types.MessageSendConfiguration( + acceptedOutputModes=list(config.accepted_output_modes), + pushNotificationConfig=FromProto.push_notification_config( + config.push_notification + ), + historyLength=config.history_length, + blocking=config.blocking, + ) + + @classmethod + def message_send_params( + cls, request: a2a_pb2.SendMessageRequest + ) -> types.MessageSendParams: + return types.MessageSendParams( + configuration=cls.message_send_configuration( + request.configuration + ), + message=cls.message(request.request), + metadata=cls.metadata(request.metadata), + ) + + @classmethod + def task_id_params( + cls, request: ( + a2a_pb2.CancelTaskRequest | + a2a_pb2.TaskSubscriptionRequest | + a2a_pb2.GetTaskPushNotificationRequest + ), + ) -> types.TaskIdParams: + # This is currently incomplete until the core sdk supports multiple + # configs for a single task. + if isinstance(request, a2a_pb2.GetTaskPushNotificationRequest): + m = re.match(_TASK_PUSH_NOTIFICATION_NAME_MATCH, task.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {task.name}' + ) + ) + return types.TaskIdParams(id = m.group(1)) + m = re.match(_TASK_NAME_MATCH, task.name) + if not m: + raise ServerError( + error=types.InvalidParamsError(message=f'No task for {task.name}') + ) + return types.TaskIdParams(id = m.group(1)) + + @classmethod + def task_push_notification_config( + cls, request: a2a_pb2.CreateTaskPushNotificationRequest, + ) -> types.TaskPushNotificationConfig: + m = re.match(_TASK_NAME_MATCH, request.parent) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.parent}' + ) + ) + return types.TaskPushNotificationConfig( + pushNotificationConfig=cls.push_notification_config( + request.config.push_notification_config, + ), + taskId=m.group(1), + ) + + @classmethod + def task_query_params( + cls, request: a2a_pb2.GetTaskRequest, + ) -> types.TaskQueryParams: + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.name}' + ) + ) + return types.TaskQueryParams( + historyLength=request.history_length if request.history_length else None, + id=m.group(1), + metadata=None, + ) + + @classmethod + def agent_card( + cls, card: a2a_pb2.AgentCard, + ) -> types.AgentCard: + return types.AgentCard( + capabilities=cls.capabilities(card.capabilities), + defaultInputModes=list(card.default_input_modes), + defaultOutputModes=list(card.default_output_modes), + description=card.description, + documentationUrl=card.documentation_url, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + securitySchemes=cls.security_schemes(card.security_schemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, + ) + + @classmethod + def capabilities( + cls, capabilities: a2a_pb2.AgentCapabilities + ) -> types.AgentCapabilities: + return types.AgentCapabilities( + streaming=capabilities.streaming, + pushNotifications=capabilities.push_notifications, + ) + + @classmethod + def provider( + cls, provider: a2a_pb2.AgentProvider | None + ) -> types.AgentProvider | None: + if not provider: + return None + return types.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, security: list[a2a_pb2.Security] | None, + ) -> list[dict[str, list[str]]] | None: + if not security: + return None + rval: list[dict[str, list[str]]] = [] + for s in security: + rval.append({k: list(v.list) for (k, v) in s.items()}) + return rval + + @classmethod + def security_schemes( + cls, schemes: dict[str, a2a_pb2.SecurityScheme] | None, + ) -> dict[str, types.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, scheme: a2a_pb2.SecurityScheme, + ) -> types.SecurityScheme: + if scheme.HasApiKeySecurityScheme(): + return types.SecurityScheme(root=types.APIKeySecurityScheme( + description=scheme.api_key_security_scheme.description, + in_=scheme.api_key_security_scheme.location, + name=scheme.api_key_security_scheme.name, + )) + if scheme.HasHttpAuthSecurityScheme(): + return types.SecurityScheme(root=types.HTTPAuthSecurityScheme( + description=scheme.http_auth_security_scheme.description, + scheme=scheme.http_auth_security_scheme.scheme, + bearerFormat=scheme.http_auth_security_scheme.bearer_format, + )) + if scheme.HasOauth2SecurityScheme(): + return types.SecurityScheme(root=types.OAuth2SecurityScheme( + description=scheme.oauth2_security_scheme.description, + flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), + )) + return types.SecurityScheme(root=types.OpenIdConnectSecurityScheme( + description=scheme.open_id_connect_security_scheme.description, + openIdConnectUrl=scheme.open_id_connect_security_scheme.open_id_connect_url, + )) + + @classmethod + def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: + if flows.HasAuthorizationCode(): + return types.OAuthFlows( + authorizationCode=types.AuthorizationCodeAuthFlow( + authorizationUrl=flows.authorization_code.authorization_url, + refreshUrl=flows.authorization_code.refresh_url, + scopes={ + k: v for (k, v) in flows.authorization_code.scopes.items() + }, + tokenUrl=flows.authorization_code.token_url, + ), + ) + if flows.HasClientCredentials(): + return types.OAuthFlows( + clientCredentials=types.ClientCredentialsAuthFlow( + refreshUrl=flows.client_credentials.refresh_url, + scopes={ + k:v for (k, v) in flows.client_credentials.scopes.items() + }, + tokenUrl=flows.client_credentials.token_url, + ), + ) + if flows.HasImplicit(): + return types.OAuthFlows( + implicit=types.ImplicitOAuthFlow( + authorizationUrl=flows.implicit.authorization_url, + refreshUrl=flows.implicit.refresh_url, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + return types.OAuthFlows( + password=types.PasswordOAuthFlow( + refreshUrl=flows.password.refresh_url, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + tokenUrl=flows.password.token_url, + ), + ) + + @classmethod + def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: + return types.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=list(skill.tags), + examples=list(skill.examples), + inputModes=list(skill.input_modes), + outputModes=list(skill.output_modes), + ) + + @classmethod + def role(cls, role: a2a_pb2.Role) -> types.Role: + match role: + case a2a_pb2.Role.ROLE_USER: + return types.Role.user + case a2a_pb2.Role.ROLE_AGENT: + return types.Role.agent + case _: + return types.Role.agent From a0681401d7d3e7ab809338a26d4c3a98106e109d Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Wed, 4 Jun 2025 20:53:42 +0000 Subject: [PATCH 12/29] Update pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 991fc8df4..058831464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "pydantic>=2.11.3", "sse-starlette>=2.3.3", "starlette>=0.46.2", + "grpcio>=1.60", + "grpcio-tools>=1.60", + "grpcio_reflection>=1.7.0", ] classifiers = [ From 27b133900a90bb0d836c3da6fbe04c4b9fea040e Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 4 Jun 2025 16:13:09 -0500 Subject: [PATCH 13/29] Add spelling fixes/excludes --- .github/actions/spelling/allow.txt | 1 + .github/actions/spelling/excludes.txt | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index d7ef1e3c6..28b21c24f 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -34,6 +34,7 @@ linting oauthoidc opensource protoc +pyi pyversions socio sse diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index dbbff9989..d4c4eef14 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -85,7 +85,6 @@ \.zip$ ^\.github/actions/spelling/ ^\.github/workflows/ -^\Qsrc/a2a/auth/__init__.py\E$ -^\Qsrc/a2a/server/request_handlers/context.py\E$ CHANGELOG.md noxfile.py +^src/a2a/grpc/ From 693bfc6b67d56993c1bbeacf8055a320008c7d22 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 4 Jun 2025 16:15:06 -0500 Subject: [PATCH 14/29] Add grpc directory to jscpd ignore --- .github/linters/.jscpd.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index 5e86d6d82..e621690e5 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -1,5 +1,5 @@ { - "ignore": ["**/.github/**", "**/.git/**", "**/tests/**"], + "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "src/a2a/grpc/**"], "threshold": 3, "reporters": ["html", "markdown"] } From e94a9fcc447ab33d718926941f830b9cbfbd84d4 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 4 Jun 2025 16:20:18 -0500 Subject: [PATCH 15/29] spelling/linting --- .github/actions/spelling/allow.txt | 4 ++++ .github/linters/.ruff.toml | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 28b21c24f..37a32afda 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -2,10 +2,12 @@ ACard AClient AError AFast +AGrpc ARequest ARun AServer AServers +AService AStarlette EUR GBP @@ -13,9 +15,11 @@ INR JPY JSONRPCt Llm +RUF aconnect adk agentic +aio autouse cla cls diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 29c4ff207..61a5d15fc 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -82,6 +82,7 @@ exclude = [ "venv", "*/migrations/*", "noxfile.py", + "src/a2a/grpc/**", ] [lint.isort] @@ -139,7 +140,7 @@ inline-quotes = "single" "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py [format] -exclude = ["types.py"] +exclude = ["types.py", "src/a2a/grpc/**"] docstring-code-format = true docstring-code-line-length = "dynamic" # Or set to 80 quote-style = "single" From feaa218481a1383b1ae02fb0794b9adf81908452 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 4 Jun 2025 16:24:08 -0500 Subject: [PATCH 16/29] Exclude grpc/ directory --- .github/linters/.ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 61a5d15fc..78c377929 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -82,7 +82,7 @@ exclude = [ "venv", "*/migrations/*", "noxfile.py", - "src/a2a/grpc/**", + "src/a2a/grpc/*.*", ] [lint.isort] From 52117a777a92625a65f22b6b160c7edb1a851def Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 4 Jun 2025 16:26:42 -0500 Subject: [PATCH 17/29] Update JSCPD to ignore `/src/a2a/grpc/**` --- .github/linters/.jscpd.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index e621690e5..5a6fcad71 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -1,5 +1,5 @@ { - "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "src/a2a/grpc/**"], + "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"], "threshold": 3, "reporters": ["html", "markdown"] } From 50539f4202c5f6abd71b5f1b4fdd85ba453bd6f3 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 5 Jun 2025 04:09:09 +0000 Subject: [PATCH 18/29] Add google.api dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 058831464..c1e7f4db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "fastapi>=0.115.12", "httpx>=0.28.1", "httpx-sse>=0.4.0", + "google-api-core>=1.26.0", "opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0", "pydantic>=2.11.3", From 29ec4d4b2201e53b358361c2631ccc8002b7dd6c Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 14:57:14 +0000 Subject: [PATCH 19/29] Fix lint/typing errors --- .../server/request_handlers/grpc_handler.py | 15 +- src/a2a/utils/proto_utils.py | 1528 +++++++++-------- 2 files changed, 781 insertions(+), 762 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 66c24804b..03f0ca707 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -102,6 +102,7 @@ async def SendMessage( return proto_utils.ToProto.task_or_message(task_or_message) except ServerError as e: await self.abort_context(e, context) + return a2a_pb2.SendMessageResponse() @validate_async_generator( lambda self: self.agent_card.capabilities.streaming, @@ -162,9 +163,12 @@ async def CancelTask( ) if task: return proto_utils.ToProto.task(task) - self.abort_context(ServerError(error=TaskNotFoundError()), context) + await self.abort_context( + ServerError(error=TaskNotFoundError()), context + ) except ServerError as e: await self.abort_context(e, context) + return a2a_pb2.Task() @validate_async_generator( lambda self: self.agent_card.capabilities.streaming, @@ -221,6 +225,7 @@ async def GetTaskPushNotification( return proto_utils.ToProto.task_push_notification_config(config) except ServerError as e: await self.abort_context(e, context) + return a2a_pb2.TaskPushNotificationConfig() @validate( lambda self: self.agent_card.capabilities.pushNotifications, @@ -230,7 +235,7 @@ async def CreateTaskPushNotification( self, request: a2a_pb2.CreateTaskPushNotificationRequest, context: grpc.aio.ServicerContext, - ) -> TaskPushNotificationConfig: + ) -> a2a_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotification' gRPC method. Requires the agent to support push notifications. @@ -259,6 +264,7 @@ async def CreateTaskPushNotification( return proto_utils.ToProto.task_push_notification_config(config) except ServerError as e: await self.abort_context(e, context) + return a2a_pb2.TaskPushNotificationConfig() async def GetTask( self, @@ -281,9 +287,12 @@ async def GetTask( ) if task: return proto_utils.ToProto.task(task) - self.abort_context(ServerError(error=TaskNotFoundError()), context) + await self.abort_context( + ServerError(error=TaskNotFoundError()), context + ) except ServerError as e: await self.abort_context(e, context) + return a2a_pb2.Task() async def GetAgentCard( self, diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index bc78abbf8..7905eae0a 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="arg-type" """Utils for converting between proto and Python types.""" import json @@ -12,770 +13,779 @@ # Regexp patterns for matching -_TASK_NAME_MATCH = r'tasks/(\w+)' -_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)' +_TASK_NAME_MATCH = r"tasks/(\w+)" +_TASK_PUSH_CONFIG_NAME_MATCH = r"tasks/(\w+)/pushNotifications/(\w+)" class ToProto: - """Converts Python types to proto types.""" - - @classmethod - def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: - if message is None: - return None - return a2a_pb2.Message( - message_id=message.messageId, - content=[ToProto.part(p) for p in message.parts], - context_id=message.contextId, - task_id=message.taskId, - role=cls.role(message.role.name), - metadata=ToProto.metadata(message.metadata), - ) - - @classmethod - def metadata( - cls, metadata: Dict[str, Any] | None - ) -> struct_pb2.Struct | None: - if metadata is None: - return None - return struct_pb2.Struct( - # TODO: Add support for other types. - fields={ - key: struct_pb2.Value(string_value=value) - for key, value in metadata.items() - if isinstance(value, str) - } - ) - - @classmethod - def part(cls, part: types.Part) -> a2a_pb2.Part: - if isinstance(part.root, types.TextPart): - return a2a_pb2.Part(text=part.root.text) - elif isinstance(part.root, types.FilePart): - return a2a_pb2.Part(file=ToProto.file(part.root.file)) - elif isinstance(part.root, types.DataPart): - return a2a_pb2.Part(data=ToProto.data(part.root.data)) - else: - raise ValueError(f'Unsupported part type: {part.root}') - - @classmethod - def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: - json_data = json.dumps(data) - return a2a_pb2.DataPart( - data=json_format.Parse( - json_data, - struct_pb2.Struct(), - ) - ) - - @classmethod - def file( - cls, file: types.FileWithUri | types.FileWithBytes - ) -> a2a_pb2.FilePart: - if isinstance(file, types.FileWithUri): - return a2a_pb2.FilePart(file_with_uri=file.uri) - return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) - - @classmethod - def task(cls, task: types.Task) -> a2a_pb2.Task: - return a2a_pb2.Task( - id=task.id, - context_id=task.contextId, - status=ToProto.task_status(task.status), - artifacts=([ - ToProto.artifact(a) for a in task.artifacts - ] if task.artifacts else None), - history=([ - ToProto.message(h) for h in task.history - ] if task.history else None), - ) - - @classmethod - def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: - return a2a_pb2.TaskStatus( - state=ToProto.task_state(status.state), - update=ToProto.message(status.message), - ) - - @classmethod - def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: - match state: - case types.TaskState.submitted: - return a2a_pb2.TaskState.TASK_STATE_SUBMITTED - case types.TaskState.working: - return a2a_pb2.TaskState.TASK_STATE_WORKING - case types.TaskState.completed: - return a2a_pb2.TaskState.TASK_STATE_COMPLETED - case types.TaskState.canceled: - return a2a_pb2.TaskState.TASK_STATE_CANCELLED - case types.TaskState.failed: - return a2a_pb2.TaskState.TASK_STATE_FAILED - case types.TaskState.input_required: - return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED - case _: - return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED - - @classmethod - def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: - return a2a_pb2.Artifact( - artifact_id=artifact.artifactId, - description=artifact.description, - metadata=ToProto.metadata(artifact.metadata), - name=artifact.name, - parts=[ToProto.part(p) for p in artifact.parts], - ) - - @classmethod - def authentication_info( - cls, info: types.PushNotificationAuthenticationInfo - ) -> a2a_pb2.AuthenticationInfo: - return a2a_pb2.AuthenticationInfo( - schemes=info.schemes, - credentials=info.credentials, - ) - - @classmethod - def push_notification_config( - cls, config: types.PushNotificationConfig - ) -> a2a_pb2.PushNotificationConfig: - return a2a_pb2.PushNotificationConfig( - id=config.id if id else "", - url=config.url, - token=config.token, - authentication=ToProto.authentication_info(config.authentication), - ) - - @classmethod - def task_artifact_update_event( - cls, event: types.TaskArtifactUpdateEvent - ) -> a2a_pb2.TaskArtifactUpdateEvent: - return a2a_pb2.TaskArtifactUpdateEvent( - task_id=event.taskId, - context_id=event.contextId, - artifact=ToProto.artifact(event.artifact), - metadata=ToProto.metadata(event.metadata), - append=event.append, - last_chunk=event.lastChunk, - ) - - @classmethod - def task_status_update_event( - cls, event: types.TaskStatusUpdateEvent - ) -> a2a_pb2.TaskStatusUpdateEvent: - return a2a_pb2.TaskStatusUpdateEvent( - task_id=event.taskId, - context_id=event.contextId, - status=ToProto.task_status(event.status), - metadata=ToProto.metadata(event.metadata), - final=event.final, - ) - - @classmethod - def message_send_configuration( - cls, config: types.MessageSendConfiguration | None - ) -> a2a_pb2.SendMessageConfiguration: - if not config: - return a2a_pb2.SendMessageConfiguration() - return a2a_pb2.SendMessageConfiguration( - accepted_output_modes=list(config.acceptedOutputModes), - push_notification=ToProto.push_notification_config( - config.pushNotificationConfig - ), - history_length=config.historyLength, - blocking=config.blocking, - ) - - @classmethod - def update_event( - cls, event: types.Task | types.Message | types.TaskStatusUpdateEvent - | types.TaskArtifactUpdateEvent - ) -> a2a_pb2.StreamResponse: - """Converts a task, message, or task update event to a StreamResponse.""" - if isinstance(event, types.TaskStatusUpdateEvent): - return a2a_pb2.StreamResponse( - status_update=ToProto.task_status_update_event(event) - ) - elif isinstance(event, types.TaskArtifactUpdateEvent): - return a2a_pb2.StreamResponse( - artifact_update=ToProto.task_artifact_update_event(event) - ) - elif isinstance(event, types.Message): - return a2a_pb2.StreamResponse(msg=ToProto.message(event)) - elif isinstance(event, types.Task): - return a2a_pb2.StreamResponse(task=ToProto.task(event)) - else: - raise ValueError(f'Unsupported event type: {type(event)}') - - @classmethod - def task_or_message( - cls, event: types.Task | types.Message - ) -> a2a_pb2.SendMessageResponse: - if isinstance(event, types.Message): - return a2a_pb2.SendMessageResponse( - msg=cls.message(event), - ) - return a2a_pb2.SendMessageResponse( - task=cls.task(event), - ) - - @classmethod - def stream_response( - cls, - event: ( - types.Message | - types.Task | - types.TaskStatusUpdateEvent | - types.TaskArtifactUpdateEvent) - ) -> a2a_pb2.StreamResponse: - if isinstance(event, types.Message): - return a2a_pb2.StreamResponse(msg=cls.message(event)) - elif isinstance(event, types.Task): - return a2a_pb2.StreamResponse(task=cls.task(event)) - elif isinstance(event, types.TaskStatusUpdateEvent): - return a2a_pb2.StreamResponse( - status_update=cls.task_status_update_event(event), - ) - return a2a_pb2.StreamResponse( - artifact_update=cls.task_artifact_update_event(event), - ) - - @classmethod - def task_push_notification_config( - cls, - config: types.TaskPushNotificationConfig - ) -> a2a_pb2.TaskPushNotificationConfig: - return a2a_pb2.TaskPushNotificationConfig( - name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}', - push_notification_config=cls.push_notification_config( - config.pushNotificationConfig, - ), - ) - - @classmethod - def agent_card( - cls, card: types.AgentCard, - ) -> a2a_pb2.AgentCard: - return a2a_pb2.AgentCard( - capabilities=cls.capabilities(card.capabilities), - default_input_modes=list(card.defaultInputModes), - default_output_modes=list(card.defaultOutputModes), - description=card.description, - documentation_url=card.documentationUrl, - name=card.name, - provider=cls.provider(card.provider), - security=cls.security(card.security), - security_schemes=cls.security_schemes(card.securitySchemes), - skills=[cls.skill(x) for x in card.skills] if card.skills else [], - url=card.url, - version=card.version, - supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard, - ) - - @classmethod - def capabilities( - cls, capabilities: types.AgentCapabilities - ) -> a2a_pb2.AgentCapabilities: - return a2a_pb2.AgentCapabilities( - streaming=capabilities.streaming, - push_notifications=capabilities.pushNotifications, - ) - - @classmethod - def provider( - cls, provider: types.AgentProvider | None - ) -> a2a_pb2.AgentProvider | None: - if not provider: - return None - return a2a_pb2.AgentProvider( - organization=provider.organization, - url=provider.url, - ) - - @classmethod - def security( - cls, security: list[dict[str, list[str]]] | None, - ) -> list[a2a_pb2.Security] | None: - if not security: - return None - rval: list[a2a_pb2.Security] = [] - for s in security: - rval.append( - a2a_pb2.Security( - schemes={ - k: a2a_pb2.StringList(list=v.list) for (k, v) in s.items() - } - ) - ) - return rval - - @classmethod - def security_schemes( - cls, schemes: dict[str, types.SecurityScheme] | None, - ) -> dict[str, a2a_pb2.SecurityScheme] | None: - if not schemes: - return None - return {k: cls.security_scheme(v) for (k, v) in schemes.items()} - - @classmethod - def security_scheme( - cls, scheme: types.SecurityScheme, - ) -> a2a_pb2.SecurityScheme: - if isinstance(scheme.root, types.ApiKeySecurityScheme): - return a2a_pb2.SecurityScheme( - api_key_security_scheme=a2a_pb2.APIKeySecurityScheme( - description=scheme.root.description, - location=scheme.root.in_, - name=scheme.root.name, - ) - ) - if isinstance(scheme.root, types.HTTPAuthSecurityScheme): - return a2a_pb2.SecurityScheme( - http_auth_security_scheme=a2a_pb2.HttpAuthSecurityScheme( - description=scheme.root.description, - scheme=scheme.root.scheme, - bearer_format=scheme.root.bearerFormat, - ) - ) - if isinstance(scheme.root, types.Oauth2SecurityScheme): - return a2a_pb2.SecurityScheme( - oauth2_security_scheme=a2a_pb2.OAuth2SecurityScheme( - description=scheme.root.description, - flows=cls.oauth2_flows(scheme.root.flows), - ) - ) - return a2a_pb2.SecurityScheme( - open_id_connect_security_scheme=a2a_pb2.OpenIdConnectSecurityScheme( - description=scheme.root.description, - open_id_connect_url=scheme.root.openIdConnectUrl, - ) - ) - - @classmethod - def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: - if flows.authorizationCode: - return a2a_pb2.OAuthFlows( - authorization_code=a2a_pb2.AuthorizationCodeAuthFlow( - authorization_url=flows.authorizationCode.authorizationUrl, - refresh_url=flows.authorizationCode.refreshUrl, - scopes={ - k: v for (k, v) in flows.authorizationCode.scopes.items() - }, - token_url=flows.authorizationCode.tokenUrl, - ), - ) - if flows.clientCredentials: - return a2a_pb2.OAuthFlows( - client_credentials=a2a_pb2.ClientCredentialsAuthFlow( - refresh_url=flows.clientCredentials.refreshUrl, - scopes={ - k:v for (k, v) in flows.clientCredentials.scopes.items() - }, - token_url=flows.client_credentials.tokenUrl, - ), - ) - if flows.implicit: - return a2a_pb2.OAuthFlows( - implicit=a2a_pb2.ImplicitOAuthFlow( - authorization_url=flows.implicit.authorization_Url, - refresh_url=flows.implicit.refreshUrl, - scopes={k: v for (k, v) in flows.implicit.scopes.items()}, - ), - ) - return a2a_pb2.OAuthFlows( - password=types.PasswordOAuthFlow( - refresh_url=flows.password.refreshUrl, - scopes={k: v for (k, v) in flows.password.scopes.items()}, - token_url=flows.password.tokenUrl, + """Converts Python types to proto types.""" + + @classmethod + def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: + if message is None: + return None + return a2a_pb2.Message( + message_id=message.messageId, + content=[ToProto.part(p) for p in message.parts], + context_id=message.contextId, + task_id=message.taskId, + role=cls.role(message.role), + metadata=ToProto.metadata(message.metadata), + ) + + @classmethod + def metadata(cls, metadata: Dict[str, Any] | None) -> struct_pb2.Struct | None: + if metadata is None: + return None + return struct_pb2.Struct( + # TODO: Add support for other types. + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in metadata.items() + if isinstance(value, str) + } + ) + + @classmethod + def part(cls, part: types.Part) -> a2a_pb2.Part: + if isinstance(part.root, types.TextPart): + return a2a_pb2.Part(text=part.root.text) + elif isinstance(part.root, types.FilePart): + return a2a_pb2.Part(file=ToProto.file(part.root.file)) + elif isinstance(part.root, types.DataPart): + return a2a_pb2.Part(data=ToProto.data(part.root.data)) + else: + raise ValueError(f"Unsupported part type: {part.root}") + + @classmethod + def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: + json_data = json.dumps(data) + return a2a_pb2.DataPart( + data=json_format.Parse( + json_data, + struct_pb2.Struct(), + ) + ) + + @classmethod + def file(cls, file: types.FileWithUri | types.FileWithBytes) -> a2a_pb2.FilePart: + if isinstance(file, types.FileWithUri): + return a2a_pb2.FilePart(file_with_uri=file.uri) + return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode("utf-8")) + + @classmethod + def task(cls, task: types.Task) -> a2a_pb2.Task: + return a2a_pb2.Task( + id=task.id, + context_id=task.contextId, + status=ToProto.task_status(task.status), + artifacts=( + [ToProto.artifact(a) for a in task.artifacts] + if task.artifacts + else None + ), + history=( + [ToProto.message(h) for h in task.history] if task.history else None + ), + ) + + @classmethod + def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: + return a2a_pb2.TaskStatus( + state=ToProto.task_state(status.state), + update=ToProto.message(status.message), + ) + + @classmethod + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: + match state: + case types.TaskState.submitted: + return a2a_pb2.TaskState.TASK_STATE_SUBMITTED + case types.TaskState.working: + return a2a_pb2.TaskState.TASK_STATE_WORKING + case types.TaskState.completed: + return a2a_pb2.TaskState.TASK_STATE_COMPLETED + case types.TaskState.canceled: + return a2a_pb2.TaskState.TASK_STATE_CANCELLED + case types.TaskState.failed: + return a2a_pb2.TaskState.TASK_STATE_FAILED + case types.TaskState.input_required: + return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED + case _: + return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED + + @classmethod + def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: + return a2a_pb2.Artifact( + artifact_id=artifact.artifactId, + description=artifact.description, + metadata=ToProto.metadata(artifact.metadata), + name=artifact.name, + parts=[ToProto.part(p) for p in artifact.parts], + ) + + @classmethod + def authentication_info( + cls, info: types.PushNotificationAuthenticationInfo + ) -> a2a_pb2.AuthenticationInfo: + return a2a_pb2.AuthenticationInfo( + schemes=info.schemes, + credentials=info.credentials, + ) + + @classmethod + def push_notification_config( + cls, config: types.PushNotificationConfig + ) -> a2a_pb2.PushNotificationConfig: + return a2a_pb2.PushNotificationConfig( + id=config.id or "", + url=config.url, + token=config.token, + authentication=ToProto.authentication_info(config.authentication), + ) + + @classmethod + def task_artifact_update_event( + cls, event: types.TaskArtifactUpdateEvent + ) -> a2a_pb2.TaskArtifactUpdateEvent: + return a2a_pb2.TaskArtifactUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + artifact=ToProto.artifact(event.artifact), + metadata=ToProto.metadata(event.metadata), + append=event.append or False, + last_chunk=event.lastChunk or False, + ) + + @classmethod + def task_status_update_event( + cls, event: types.TaskStatusUpdateEvent + ) -> a2a_pb2.TaskStatusUpdateEvent: + return a2a_pb2.TaskStatusUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + status=ToProto.task_status(event.status), + metadata=ToProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def message_send_configuration( + cls, config: types.MessageSendConfiguration | None + ) -> a2a_pb2.SendMessageConfiguration: + if not config: + return a2a_pb2.SendMessageConfiguration() + return a2a_pb2.SendMessageConfiguration( + accepted_output_modes=list(config.acceptedOutputModes), + push_notification=ToProto.push_notification_config( + config.pushNotificationConfig + ), + history_length=config.historyLength, + blocking=config.blocking or False, + ) + + @classmethod + def update_event( + cls, + event: types.Task + | types.Message + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent, + ) -> a2a_pb2.StreamResponse: + """Converts a task, message, or task update event to a StreamResponse.""" + if isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=ToProto.task_status_update_event(event) + ) + elif isinstance(event, types.TaskArtifactUpdateEvent): + return a2a_pb2.StreamResponse( + artifact_update=ToProto.task_artifact_update_event(event) + ) + elif isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=ToProto.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=ToProto.task(event)) + else: + raise ValueError(f"Unsupported event type: {type(event)}") + + @classmethod + def task_or_message( + cls, event: types.Task | types.Message + ) -> a2a_pb2.SendMessageResponse: + if isinstance(event, types.Message): + return a2a_pb2.SendMessageResponse( + msg=cls.message(event), + ) + return a2a_pb2.SendMessageResponse( + task=cls.task(event), + ) + + @classmethod + def stream_response( + cls, + event: ( + types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent ), - ) - - @classmethod - def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: - return a2a_pb2.AgentSkill( - id=skill.id, - name=skill.name, - description=skill.description, - tags=skill.tags, - examples=skill.examples, - input_modes=skill.inputModes, - output_modes=skill.outputModes, - ) - - @classmethod - def role(cls, role: types.Role) -> a2a_pb2.Role: - match role: - case types.Role.user: - return a2a_pb2.Role.ROLE_USER - case types.Role.agent: - return a2a_pb2.Role.ROLE_AGENT - case _: - return a2a_pb2.Role.ROLE_UNSPECIFIED + ) -> a2a_pb2.StreamResponse: + if isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=cls.message(event)) + elif isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=cls.task(event)) + elif isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=cls.task_status_update_event(event), + ) + return a2a_pb2.StreamResponse( + artifact_update=cls.task_artifact_update_event(event), + ) + + @classmethod + def task_push_notification_config( + cls, config: types.TaskPushNotificationConfig + ) -> a2a_pb2.TaskPushNotificationConfig: + return a2a_pb2.TaskPushNotificationConfig( + name=f"tasks/{config.taskId}/pushNotifications/{config.taskId}", + push_notification_config=cls.push_notification_config( + config.pushNotificationConfig, + ), + ) + + @classmethod + def agent_card( + cls, + card: types.AgentCard, + ) -> a2a_pb2.AgentCard: + return a2a_pb2.AgentCard( + capabilities=cls.capabilities(card.capabilities), + default_input_modes=list(card.defaultInputModes), + default_output_modes=list(card.defaultOutputModes), + description=card.description, + documentation_url=card.documentationUrl, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + security_schemes=cls.security_schemes(card.securitySchemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard, + ) + + @classmethod + def capabilities( + cls, capabilities: types.AgentCapabilities + ) -> a2a_pb2.AgentCapabilities: + return a2a_pb2.AgentCapabilities( + streaming=capabilities.streaming, + push_notifications=capabilities.pushNotifications, + ) + + @classmethod + def provider( + cls, provider: types.AgentProvider | None + ) -> a2a_pb2.AgentProvider | None: + if not provider: + return None + return a2a_pb2.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, + security: list[dict[str, list[str]]] | None, + ) -> list[a2a_pb2.Security] | None: + if not security: + return None + rval: list[a2a_pb2.Security] = [] + for s in security: + rval.append( + a2a_pb2.Security( + schemes={k: a2a_pb2.StringList(list=v) for (k, v) in s.items()} + ) + ) + return rval + + @classmethod + def security_schemes( + cls, + schemes: dict[str, types.SecurityScheme] | None, + ) -> dict[str, a2a_pb2.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, + scheme: types.SecurityScheme, + ) -> a2a_pb2.SecurityScheme: + if isinstance(scheme.root, types.APIKeySecurityScheme): + return a2a_pb2.SecurityScheme( + api_key_security_scheme=a2a_pb2.APIKeySecurityScheme( + description=scheme.root.description, + location=scheme.root.in_, + name=scheme.root.name, + ) + ) + if isinstance(scheme.root, types.HTTPAuthSecurityScheme): + return a2a_pb2.SecurityScheme( + http_auth_security_scheme=a2a_pb2.HTTPAuthSecurityScheme( + description=scheme.root.description, + scheme=scheme.root.scheme, + bearer_format=scheme.root.bearerFormat, + ) + ) + if isinstance(scheme.root, types.OAuth2SecurityScheme): + return a2a_pb2.SecurityScheme( + oauth2_security_scheme=a2a_pb2.OAuth2SecurityScheme( + description=scheme.root.description, + flows=cls.oauth2_flows(scheme.root.flows), + ) + ) + return a2a_pb2.SecurityScheme( + open_id_connect_security_scheme=a2a_pb2.OpenIdConnectSecurityScheme( + description=scheme.root.description, + open_id_connect_url=scheme.root.openIdConnectUrl, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: + if flows.authorizationCode: + return a2a_pb2.OAuthFlows( + authorization_code=a2a_pb2.AuthorizationCodeOAuthFlow( + authorization_url=flows.authorizationCode.authorizationUrl, + refresh_url=flows.authorizationCode.refreshUrl, + scopes={ + k: v for (k, v) in flows.authorizationCode.scopes.items() + }, + token_url=flows.authorizationCode.tokenUrl, + ), + ) + if flows.clientCredentials: + return a2a_pb2.OAuthFlows( + client_credentials=a2a_pb2.ClientCredentialsOAuthFlow( + refresh_url=flows.clientCredentials.refreshUrl, + scopes={ + k: v for (k, v) in flows.clientCredentials.scopes.items() + }, + token_url=flows.client_credentials.tokenUrl, + ), + ) + if flows.implicit: + return a2a_pb2.OAuthFlows( + implicit=a2a_pb2.ImplicitOAuthFlow( + authorization_url=flows.implicit.authorization_Url, + refresh_url=flows.implicit.refreshUrl, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + if flows.password: + return a2a_pb2.OAuthFlows( + password=types.PasswordOAuthFlow( + refresh_url=flows.password.refreshUrl, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + token_url=flows.password.tokenUrl, + ), + ) + raise ValueError("Unknown oauth flow definition") + + @classmethod + def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: + return a2a_pb2.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=skill.tags, + examples=skill.examples, + input_modes=skill.inputModes, + output_modes=skill.outputModes, + ) + + @classmethod + def role(cls, role: types.Role) -> a2a_pb2.Role: + match role: + case types.Role.user: + return a2a_pb2.Role.ROLE_USER + case types.Role.agent: + return a2a_pb2.Role.ROLE_AGENT + case _: + return a2a_pb2.Role.ROLE_UNSPECIFIED class FromProto: - """Converts proto types to Python types.""" - - @classmethod - def message(cls, message: a2a_pb2.Message) -> types.Message: - return types.Message( - messageId=message.message_id, - parts=[FromProto.part(p) for p in message.content], - contextId=message.context_id, - taskId=message.task_id, - role=FromProto.role(message.role), - metadata=FromProto.metadata(message.metadata), - ) - - @classmethod - def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: - return { - key: value.string_value - for key, value in metadata.fields.items() - if value.string_value - } - - @classmethod - def part(cls, part: a2a_pb2.Part) -> types.Part: - if part.HasField('text'): - return types.Part(root=types.TextPart(text=part.text)) - elif part.HasField('file'): - return types.Part(root=types.FilePart(file=FromProto.file(part.file))) - elif part.HasField('data'): - return types.Part(root=types.DataPart(data=FromProto.data(part.data))) - else: - raise ValueError(f'Unsupported part type: {part}') - - @classmethod - def data(cls, data: a2a_pb2.DataPart) -> Dict[str, Any]: - json_data = json_format.MessageToJson(data.data) - return json.loads(json_data) - - @classmethod - def file( - cls, file: a2a_pb2.FilePart - ) -> types.FileWithUri | types.FileWithBytes: - if file.HasField('file_with_uri'): - return types.FileWithUri(uri=file.file_with_uri) - return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) - - @classmethod - def task(cls, task: a2a_pb2.Task) -> types.Task: - return types.Task( - id=task.id, - contextId=task.context_id, - status=FromProto.task_status(task.status), - artifacts=[FromProto.artifact(a) for a in task.artifacts], - history=[FromProto.message(h) for h in task.history], - ) - - @classmethod - def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: - return types.TaskStatus( - state=FromProto.task_state(status.state), - message=FromProto.message(status.update), - ) - - @classmethod - def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: - match state: - case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: - return types.TaskState.submitted - case a2a_pb2.TaskState.TASK_STATE_WORKING: - return types.TaskState.working - case a2a_pb2.TaskState.TASK_STATE_COMPLETED: - return types.TaskState.completed - case a2a_pb2.TaskState.TASK_STATE_CANCELLED: - return types.TaskState.canceled - case a2a_pb2.TaskState.TASK_STATE_FAILED: - return types.TaskState.failed - case a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED: - return types.TaskState.input_required - case _: - return types.TaskState.unknown - - @classmethod - def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: - return types.Artifact( - artifactId=artifact.artifact_id, - description=artifact.description, - metadata=FromProto.metadata(artifact.metadata), - name=artifact.name, - parts=[FromProto.part(p) for p in artifact.parts], - ) - - @classmethod - def task_artifact_update_event( - cls, event: a2a_pb2.TaskArtifactUpdateEvent - ) -> types.TaskArtifactUpdateEvent: - return types.TaskArtifactUpdateEvent( - taskId=event.task_id, - contextId=event.context_id, - artifact=FromProto.artifact(event.artifact), - metadata=FromProto.metadata(event.metadata), - append=event.append, - lastChunk=event.last_chunk, - ) - - @classmethod - def task_status_update_event( - cls, event: a2a_pb2.TaskStatusUpdateEvent - ) -> types.TaskStatusUpdateEvent: - return types.TaskStatusUpdateEvent( - taskId=event.task_id, - contextId=event.context_id, - status=FromProto.task_status(event.status), - metadata=FromProto.metadata(event.metadata), - final=event.final, - ) - - @classmethod - def push_notification_config( - cls, config: a2a_pb2.PushNotificationConfig - ) -> types.PushNotificationConfig: - return types.PushNotificationConfig( - id=config.id, - url=config.url, - token=config.token, - authentication=FromProto.authentication_info(config.authentication), - ) - - @classmethod - def authentication_info( - cls, info: a2a_pb2.AuthenticationInfo - ) -> types.PushNotificationAuthenticationInfo: - return types.PushNotificationAuthenticationInfo( - schemes=list(info.schemes), - credentials=info.credentials, - ) - - @classmethod - def message_send_configuration( - cls, config: a2a_pb2.SendMessageConfiguration - ) -> types.MessageSendConfiguration: - return types.MessageSendConfiguration( - acceptedOutputModes=list(config.accepted_output_modes), - pushNotificationConfig=FromProto.push_notification_config( - config.push_notification - ), - historyLength=config.history_length, - blocking=config.blocking, - ) - - @classmethod - def message_send_params( - cls, request: a2a_pb2.SendMessageRequest - ) -> types.MessageSendParams: - return types.MessageSendParams( - configuration=cls.message_send_configuration( - request.configuration - ), - message=cls.message(request.request), - metadata=cls.metadata(request.metadata), - ) - - @classmethod - def task_id_params( - cls, request: ( - a2a_pb2.CancelTaskRequest | - a2a_pb2.TaskSubscriptionRequest | - a2a_pb2.GetTaskPushNotificationRequest - ), - ) -> types.TaskIdParams: - # This is currently incomplete until the core sdk supports multiple - # configs for a single task. - if isinstance(request, a2a_pb2.GetTaskPushNotificationRequest): - m = re.match(_TASK_PUSH_NOTIFICATION_NAME_MATCH, task.name) - if not m: - raise ServerError( - error=types.InvalidParamsError( - message=f'No task for {task.name}' - ) - ) - return types.TaskIdParams(id = m.group(1)) - m = re.match(_TASK_NAME_MATCH, task.name) - if not m: - raise ServerError( - error=types.InvalidParamsError(message=f'No task for {task.name}') - ) - return types.TaskIdParams(id = m.group(1)) - - @classmethod - def task_push_notification_config( - cls, request: a2a_pb2.CreateTaskPushNotificationRequest, - ) -> types.TaskPushNotificationConfig: - m = re.match(_TASK_NAME_MATCH, request.parent) - if not m: - raise ServerError( - error=types.InvalidParamsError( - message=f'No task for {request.parent}' - ) - ) - return types.TaskPushNotificationConfig( - pushNotificationConfig=cls.push_notification_config( - request.config.push_notification_config, - ), - taskId=m.group(1), - ) - - @classmethod - def task_query_params( - cls, request: a2a_pb2.GetTaskRequest, - ) -> types.TaskQueryParams: - m = re.match(_TASK_NAME_MATCH, request.name) - if not m: - raise ServerError( - error=types.InvalidParamsError( - message=f'No task for {request.name}' - ) - ) - return types.TaskQueryParams( - historyLength=request.history_length if request.history_length else None, - id=m.group(1), - metadata=None, - ) - - @classmethod - def agent_card( - cls, card: a2a_pb2.AgentCard, - ) -> types.AgentCard: - return types.AgentCard( - capabilities=cls.capabilities(card.capabilities), - defaultInputModes=list(card.default_input_modes), - defaultOutputModes=list(card.default_output_modes), - description=card.description, - documentationUrl=card.documentation_url, - name=card.name, - provider=cls.provider(card.provider), - security=cls.security(card.security), - securitySchemes=cls.security_schemes(card.security_schemes), - skills=[cls.skill(x) for x in card.skills] if card.skills else [], - url=card.url, - version=card.version, - supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, - ) - - @classmethod - def capabilities( - cls, capabilities: a2a_pb2.AgentCapabilities - ) -> types.AgentCapabilities: - return types.AgentCapabilities( - streaming=capabilities.streaming, - pushNotifications=capabilities.push_notifications, - ) - - @classmethod - def provider( - cls, provider: a2a_pb2.AgentProvider | None - ) -> types.AgentProvider | None: - if not provider: - return None - return types.AgentProvider( - organization=provider.organization, - url=provider.url, - ) - - @classmethod - def security( - cls, security: list[a2a_pb2.Security] | None, - ) -> list[dict[str, list[str]]] | None: - if not security: - return None - rval: list[dict[str, list[str]]] = [] - for s in security: - rval.append({k: list(v.list) for (k, v) in s.items()}) - return rval - - @classmethod - def security_schemes( - cls, schemes: dict[str, a2a_pb2.SecurityScheme] | None, - ) -> dict[str, types.SecurityScheme] | None: - if not schemes: - return None - return {k: cls.security_scheme(v) for (k, v) in schemes.items()} - - @classmethod - def security_scheme( - cls, scheme: a2a_pb2.SecurityScheme, - ) -> types.SecurityScheme: - if scheme.HasApiKeySecurityScheme(): - return types.SecurityScheme(root=types.APIKeySecurityScheme( - description=scheme.api_key_security_scheme.description, - in_=scheme.api_key_security_scheme.location, - name=scheme.api_key_security_scheme.name, - )) - if scheme.HasHttpAuthSecurityScheme(): - return types.SecurityScheme(root=types.HTTPAuthSecurityScheme( - description=scheme.http_auth_security_scheme.description, - scheme=scheme.http_auth_security_scheme.scheme, - bearerFormat=scheme.http_auth_security_scheme.bearer_format, - )) - if scheme.HasOauth2SecurityScheme(): - return types.SecurityScheme(root=types.OAuth2SecurityScheme( - description=scheme.oauth2_security_scheme.description, - flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), - )) - return types.SecurityScheme(root=types.OpenIdConnectSecurityScheme( - description=scheme.open_id_connect_security_scheme.description, - openIdConnectUrl=scheme.open_id_connect_security_scheme.open_id_connect_url, - )) - - @classmethod - def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: - if flows.HasAuthorizationCode(): - return types.OAuthFlows( - authorizationCode=types.AuthorizationCodeAuthFlow( - authorizationUrl=flows.authorization_code.authorization_url, - refreshUrl=flows.authorization_code.refresh_url, - scopes={ - k: v for (k, v) in flows.authorization_code.scopes.items() - }, - tokenUrl=flows.authorization_code.token_url, - ), - ) - if flows.HasClientCredentials(): - return types.OAuthFlows( - clientCredentials=types.ClientCredentialsAuthFlow( - refreshUrl=flows.client_credentials.refresh_url, - scopes={ - k:v for (k, v) in flows.client_credentials.scopes.items() - }, - tokenUrl=flows.client_credentials.token_url, - ), - ) - if flows.HasImplicit(): - return types.OAuthFlows( - implicit=types.ImplicitOAuthFlow( - authorizationUrl=flows.implicit.authorization_url, - refreshUrl=flows.implicit.refresh_url, - scopes={k: v for (k, v) in flows.implicit.scopes.items()}, - ), - ) - return types.OAuthFlows( - password=types.PasswordOAuthFlow( - refreshUrl=flows.password.refresh_url, - scopes={k: v for (k, v) in flows.password.scopes.items()}, - tokenUrl=flows.password.token_url, + """Converts proto types to Python types.""" + + @classmethod + def message(cls, message: a2a_pb2.Message) -> types.Message: + return types.Message( + messageId=message.message_id, + parts=[FromProto.part(p) for p in message.content], + contextId=message.context_id, + taskId=message.task_id, + role=FromProto.role(message.role), + metadata=FromProto.metadata(message.metadata), + ) + + @classmethod + def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: + return { + key: value.string_value + for key, value in metadata.fields.items() + if value.string_value + } + + @classmethod + def part(cls, part: a2a_pb2.Part) -> types.Part: + if part.HasField("text"): + return types.Part(root=types.TextPart(text=part.text)) + elif part.HasField("file"): + return types.Part(root=types.FilePart(file=FromProto.file(part.file))) + elif part.HasField("data"): + return types.Part(root=types.DataPart(data=FromProto.data(part.data))) + else: + raise ValueError(f"Unsupported part type: {part}") + + @classmethod + def data(cls, data: a2a_pb2.DataPart) -> Dict[str, Any]: + json_data = json_format.MessageToJson(data.data) + return json.loads(json_data) + + @classmethod + def file(cls, file: a2a_pb2.FilePart) -> types.FileWithUri | types.FileWithBytes: + if file.HasField("file_with_uri"): + return types.FileWithUri(uri=file.file_with_uri) + return types.FileWithBytes(bytes=file.file_with_bytes.decode("utf-8")) + + @classmethod + def task(cls, task: a2a_pb2.Task) -> types.Task: + return types.Task( + id=task.id, + contextId=task.context_id, + status=FromProto.task_status(task.status), + artifacts=[FromProto.artifact(a) for a in task.artifacts], + history=[FromProto.message(h) for h in task.history], + ) + + @classmethod + def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: + return types.TaskStatus( + state=FromProto.task_state(status.state), + message=FromProto.message(status.update), + ) + + @classmethod + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: + match state: + case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: + return types.TaskState.submitted + case a2a_pb2.TaskState.TASK_STATE_WORKING: + return types.TaskState.working + case a2a_pb2.TaskState.TASK_STATE_COMPLETED: + return types.TaskState.completed + case a2a_pb2.TaskState.TASK_STATE_CANCELLED: + return types.TaskState.canceled + case a2a_pb2.TaskState.TASK_STATE_FAILED: + return types.TaskState.failed + case a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED: + return types.TaskState.input_required + case _: + return types.TaskState.unknown + + @classmethod + def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: + return types.Artifact( + artifactId=artifact.artifact_id, + description=artifact.description, + metadata=FromProto.metadata(artifact.metadata), + name=artifact.name, + parts=[FromProto.part(p) for p in artifact.parts], + ) + + @classmethod + def task_artifact_update_event( + cls, event: a2a_pb2.TaskArtifactUpdateEvent + ) -> types.TaskArtifactUpdateEvent: + return types.TaskArtifactUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + artifact=FromProto.artifact(event.artifact), + metadata=FromProto.metadata(event.metadata), + append=event.append, + lastChunk=event.last_chunk, + ) + + @classmethod + def task_status_update_event( + cls, event: a2a_pb2.TaskStatusUpdateEvent + ) -> types.TaskStatusUpdateEvent: + return types.TaskStatusUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + status=FromProto.task_status(event.status), + metadata=FromProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def push_notification_config( + cls, config: a2a_pb2.PushNotificationConfig + ) -> types.PushNotificationConfig: + return types.PushNotificationConfig( + id=config.id, + url=config.url, + token=config.token, + authentication=FromProto.authentication_info(config.authentication), + ) + + @classmethod + def authentication_info( + cls, info: a2a_pb2.AuthenticationInfo + ) -> types.PushNotificationAuthenticationInfo: + return types.PushNotificationAuthenticationInfo( + schemes=list(info.schemes), + credentials=info.credentials, + ) + + @classmethod + def message_send_configuration( + cls, config: a2a_pb2.SendMessageConfiguration + ) -> types.MessageSendConfiguration: + return types.MessageSendConfiguration( + acceptedOutputModes=list(config.accepted_output_modes), + pushNotificationConfig=FromProto.push_notification_config( + config.push_notification + ), + historyLength=config.history_length, + blocking=config.blocking, + ) + + @classmethod + def message_send_params( + cls, request: a2a_pb2.SendMessageRequest + ) -> types.MessageSendParams: + return types.MessageSendParams( + configuration=cls.message_send_configuration(request.configuration), + message=cls.message(request.request), + metadata=cls.metadata(request.metadata), + ) + + @classmethod + def task_id_params( + cls, + request: ( + a2a_pb2.CancelTaskRequest + | a2a_pb2.TaskSubscriptionRequest + | a2a_pb2.GetTaskPushNotificationRequest ), - ) - - @classmethod - def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: - return types.AgentSkill( - id=skill.id, - name=skill.name, - description=skill.description, - tags=list(skill.tags), - examples=list(skill.examples), - inputModes=list(skill.input_modes), - outputModes=list(skill.output_modes), - ) - - @classmethod - def role(cls, role: a2a_pb2.Role) -> types.Role: - match role: - case a2a_pb2.Role.ROLE_USER: - return types.Role.user - case a2a_pb2.Role.ROLE_AGENT: - return types.Role.agent - case _: - return types.Role.agent + ) -> types.TaskIdParams: + # This is currently incomplete until the core sdk supports multiple + # configs for a single task. + if isinstance(request, a2a_pb2.GetTaskPushNotificationRequest): + m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f"No task for {request.name}" + ) + ) + return types.TaskIdParams(id=m.group(1)) + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError(message=f"No task for {task.name}") + ) + return types.TaskIdParams(id=m.group(1)) + + @classmethod + def task_push_notification_config( + cls, + request: a2a_pb2.CreateTaskPushNotificationRequest, + ) -> types.TaskPushNotificationConfig: + m = re.match(_TASK_NAME_MATCH, request.parent) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f"No task for {request.parent}" + ) + ) + return types.TaskPushNotificationConfig( + pushNotificationConfig=cls.push_notification_config( + request.config.push_notification_config, + ), + taskId=m.group(1), + ) + + @classmethod + def agent_card( + cls, + card: a2a_pb2.AgentCard, + ) -> types.AgentCard: + return types.AgentCard( + capabilities=cls.capabilities(card.capabilities), + defaultInputModes=list(card.default_input_modes), + defaultOutputModes=list(card.default_output_modes), + description=card.description, + documentationUrl=card.documentation_url, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(list(card.security)), + securitySchemes=cls.security_schemes(dict(card.security_schemes)), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, + ) + + @classmethod + def task_query_params( + cls, + request: a2a_pb2.GetTaskRequest, + ) -> types.TaskQueryParams: + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f"No task for {request.name}" + ) + ) + return types.TaskQueryParams( + historyLength=request.history_length + if request.history_length + else None, + id=m.group(1), + metadata=None, + ) + + @classmethod + def capabilities( + cls, capabilities: a2a_pb2.AgentCapabilities + ) -> types.AgentCapabilities: + return types.AgentCapabilities( + streaming=capabilities.streaming, + pushNotifications=capabilities.push_notifications, + ) + + @classmethod + def security( + cls, + security: list[a2a_pb2.Security] | None, + ) -> list[dict[str, list[str]]] | None: + if not security: + return None + rval: list[dict[str, list[str]]] = [] + for s in security: + rval.append({k: list(v.list) for (k, v) in s.schemes.items()}) + return rval + + @classmethod + def provider( + cls, provider: a2a_pb2.AgentProvider | None + ) -> types.AgentProvider | None: + if not provider: + return None + return types.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security_scheme( + cls, + scheme: a2a_pb2.SecurityScheme, + ) -> types.SecurityScheme: + if scheme.HasField("api_key_security_scheme"): + return types.SecurityScheme( + root=types.APIKeySecurityScheme( + description=scheme.api_key_security_scheme.description, + in_=scheme.api_key_security_scheme.location, + name=scheme.api_key_security_scheme.name, + ) + ) + if scheme.HasField("http_auth_security_scheme"): + return types.SecurityScheme( + root=types.HTTPAuthSecurityScheme( + description=scheme.http_auth_security_scheme.description, + scheme=scheme.http_auth_security_scheme.scheme, + bearerFormat=scheme.http_auth_security_scheme.bearer_format, + ) + ) + if scheme.HasField("oauth2_security_scheme"): + return types.SecurityScheme( + root=types.OAuth2SecurityScheme( + description=scheme.oauth2_security_scheme.description, + flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), + ) + ) + return types.SecurityScheme( + root=types.OpenIdConnectSecurityScheme( + description=scheme.open_id_connect_security_scheme.description, + openIdConnectUrl=scheme.open_id_connect_security_scheme.open_id_connect_url, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: + if flows.HasField("authorization_code"): + return types.OAuthFlows( + authorizationCode=types.AuthorizationCodeAuthFlow( + authorizationUrl=flows.authorization_code.authorization_url, + refreshUrl=flows.authorization_code.refresh_url, + scopes={ + k: v for (k, v) in flows.authorization_code.scopes.items() + }, + tokenUrl=flows.authorization_code.token_url, + ), + ) + if flows.HasField("client_credentials"): + return types.OAuthFlows( + clientCredentials=types.ClientCredentialsAuthFlow( + refreshUrl=flows.client_credentials.refresh_url, + scopes={ + k: v for (k, v) in flows.client_credentials.scopes.items() + }, + tokenUrl=flows.client_credentials.token_url, + ), + ) + if flows.HasField("implicit"): + return types.OAuthFlows( + implicit=types.ImplicitOAuthFlow( + authorizationUrl=flows.implicit.authorization_url, + refreshUrl=flows.implicit.refresh_url, + scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + ), + ) + return types.OAuthFlows( + password=types.PasswordOAuthFlow( + refreshUrl=flows.password.refresh_url, + scopes={k: v for (k, v) in flows.password.scopes.items()}, + tokenUrl=flows.password.token_url, + ), + ) + + @classmethod + def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: + return types.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=list(skill.tags), + examples=list(skill.examples), + inputModes=list(skill.input_modes), + outputModes=list(skill.output_modes), + ) + + @classmethod + def role(cls, role: a2a_pb2.Role) -> types.Role: + match role: + case a2a_pb2.Role.ROLE_USER: + return types.Role.user + case a2a_pb2.Role.ROLE_AGENT: + return types.Role.agent + case _: + return types.Role.agent From 4d7270085638f957ed943e0ba4fd92e7bb80be4c Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 15:08:52 +0000 Subject: [PATCH 20/29] Fix more lint errors --- src/a2a/server/request_handlers/grpc_handler.py | 11 +++++------ src/a2a/utils/proto_utils.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 03f0ca707..270afbcf3 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -65,8 +65,8 @@ def __init__( Args: agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegat -e requests to. + request_handler: The underlying `RequestHandler` instance to + delegate requests to. """ self.agent_card = agent_card self.request_handler = request_handler @@ -84,10 +84,9 @@ async def SendMessage( context: Context provided by the server. Returns: - A `SendMessageResponse` object containing the result (Task or Messag -e) - or throws an error response if a `ServerError` is raised by the han -dler. + A `SendMessageResponse` object containing the result (Task or + Message) or throws an error response if a `ServerError` is raised + by the handler. """ try: # Construct the server context object diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 7905eae0a..36c6ccd87 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -138,7 +138,7 @@ def push_notification_config( cls, config: types.PushNotificationConfig ) -> a2a_pb2.PushNotificationConfig: return a2a_pb2.PushNotificationConfig( - id=config.id or "", + id=config.id or '', url=config.url, token=config.token, authentication=ToProto.authentication_info(config.authentication), From a427f151c13eec3eb8fc4d55d3c73f642538f1e2 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 15:50:19 +0000 Subject: [PATCH 21/29] Yet more lint/mypy fixes --- src/a2a/utils/proto_utils.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 36c6ccd87..e118248dc 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -85,7 +85,7 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: else None ), history=( - [ToProto.message(h) for h in task.history] if task.history else None + [ToProto.message(h) for h in task.history] if task.history else None # type: ignore[misc] ), ) @@ -374,20 +374,20 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: scopes={ k: v for (k, v) in flows.clientCredentials.scopes.items() }, - token_url=flows.client_credentials.tokenUrl, + token_url=flows.clientCredentials.tokenUrl, ), ) if flows.implicit: return a2a_pb2.OAuthFlows( implicit=a2a_pb2.ImplicitOAuthFlow( - authorization_url=flows.implicit.authorization_Url, + authorization_url=flows.implicit.authorizationUrl, refresh_url=flows.implicit.refreshUrl, scopes={k: v for (k, v) in flows.implicit.scopes.items()}, ), ) if flows.password: return a2a_pb2.OAuthFlows( - password=types.PasswordOAuthFlow( + password=a2a_pb2.PasswordOAuthFlow( refresh_url=flows.password.refreshUrl, scopes={k: v for (k, v) in flows.password.scopes.items()}, token_url=flows.password.tokenUrl, @@ -598,7 +598,9 @@ def task_id_params( m = re.match(_TASK_NAME_MATCH, request.name) if not m: raise ServerError( - error=types.InvalidParamsError(message=f"No task for {task.name}") + error=types.InvalidParamsError( + message=f"No task for {request.name}" + ) ) return types.TaskIdParams(id=m.group(1)) @@ -694,19 +696,27 @@ def provider( url=provider.url, ) + @classmethod + def security_schemes( + cls, + schemes: dict[str, a2a_pb2.SecurityScheme] + ) -> dict[str, types.SecurityScheme]: + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + @classmethod def security_scheme( cls, scheme: a2a_pb2.SecurityScheme, ) -> types.SecurityScheme: if scheme.HasField("api_key_security_scheme"): - return types.SecurityScheme( + ss = types.SecurityScheme( root=types.APIKeySecurityScheme( description=scheme.api_key_security_scheme.description, - in_=scheme.api_key_security_scheme.location, name=scheme.api_key_security_scheme.name, + in_= scheme.api_key_security_scheme.location, # type: ignore[call-arg] ) ) + return ss if scheme.HasField("http_auth_security_scheme"): return types.SecurityScheme( root=types.HTTPAuthSecurityScheme( @@ -733,7 +743,7 @@ def security_scheme( def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: if flows.HasField("authorization_code"): return types.OAuthFlows( - authorizationCode=types.AuthorizationCodeAuthFlow( + authorizationCode=types.AuthorizationCodeOAuthFlow( authorizationUrl=flows.authorization_code.authorization_url, refreshUrl=flows.authorization_code.refresh_url, scopes={ @@ -744,7 +754,7 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: ) if flows.HasField("client_credentials"): return types.OAuthFlows( - clientCredentials=types.ClientCredentialsAuthFlow( + clientCredentials=types.ClientCredentialsOAuthFlow( refreshUrl=flows.client_credentials.refresh_url, scopes={ k: v for (k, v) in flows.client_credentials.scopes.items() From 83fdeafeb151e42346819ee61b512e3be03fefa2 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 6 Jun 2025 11:17:45 -0500 Subject: [PATCH 22/29] Update Linter and nox formatter to exclude grpc/ directory --- .github/linters/.ruff.toml | 7 +++++-- .github/workflows/linter.yaml | 1 + noxfile.py | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 78c377929..5e84c97a5 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -82,7 +82,7 @@ exclude = [ "venv", "*/migrations/*", "noxfile.py", - "src/a2a/grpc/*.*", + "src/a2a/grpc/**", ] [lint.isort] @@ -140,7 +140,10 @@ inline-quotes = "single" "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py [format] -exclude = ["types.py", "src/a2a/grpc/**"] +exclude = [ + "types.py", + "src/a2a/grpc/**", +] docstring-code-format = true docstring-code-line-length = "dynamic" # Or set to 80 quote-style = "single" diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index a657c5a30..890e81ae9 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -64,4 +64,5 @@ jobs: VALIDATE_GIT_COMMITLINT: false PYTHON_MYPY_CONFIG_FILE: .mypy.ini FILTER_REGEX_INCLUDE: ".*src/**/*" + FILTER_REGEX_EXCLUDE: ".*src/a2a/grpc/**/*" PYTHON_RUFF_CONFIG_FILE: .ruff.toml diff --git a/noxfile.py b/noxfile.py index 380d2a28d..60dd15efa 100644 --- a/noxfile.py +++ b/noxfile.py @@ -103,7 +103,9 @@ def format(session) -> None: } ) - lint_paths_py = [f for f in changed_files if f.endswith('.py')] + lint_paths_py = [ + f for f in changed_files if f.endswith('.py') and 'grpc/' not in f + ] if not lint_paths_py: session.log('No changed Python files to lint.') @@ -111,6 +113,7 @@ def format(session) -> None: session.install( 'types-requests', + 'types-protobuf', 'pyupgrade', 'autoflake', 'ruff', From a19fddb9b706c53c570dc6bb60e58668c140788d Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 6 Jun 2025 11:17:50 -0500 Subject: [PATCH 23/29] Fix formatting --- src/a2a/client/__init__.py | 2 +- src/a2a/client/grpc_client.py | 22 ++- src/a2a/server/request_handlers/__init__.py | 4 +- .../default_request_handler.py | 2 +- .../server/request_handlers/grpc_handler.py | 37 +++-- src/a2a/utils/helpers.py | 3 + src/a2a/utils/proto_utils.py | 129 ++++++++++-------- 7 files changed, 106 insertions(+), 93 deletions(-) diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 1a2bb5449..e91c9eb7a 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -1,12 +1,12 @@ """Client-side components for interacting with an A2A agent.""" from a2a.client.client import A2ACardResolver, A2AClient -from a2a.client.grpc_client import A2AGrpcClient from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, A2AClientJSONError, ) +from a2a.client.grpc_client import A2AGrpcClient from a2a.client.helpers import create_text_message_object diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index 40fea5e42..d9f14b7f7 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -1,26 +1,24 @@ -import json import logging + from collections.abc import AsyncGenerator -from typing import Any -from uuid import uuid4 + import grpc -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, + Message, MessageSendParams, Task, - TaskStatusUpdateEvent, TaskArtifactUpdateEvent, - TaskPushNotificationConfig, TaskIdParams, + TaskPushNotificationConfig, TaskQueryParams, - Message, + TaskStatusUpdateEvent, ) -from a2a.utils.telemetry import SpanKind, trace_class from a2a.utils import proto_utils -from a2a.grpc import a2a_pb2_grpc -from a2a.grpc import a2a_pb2 +from a2a.utils.telemetry import SpanKind, trace_class + logger = logging.getLogger(__name__) @@ -48,7 +46,7 @@ def __init__( async def send_message( self, request: MessageSendParams, - ) -> Task | Message : + ) -> Task | Message: """Sends a non-streaming message request to the agent. Args: @@ -174,7 +172,7 @@ async def set_task_callback( async def get_task_callback( self, - request: TaskIdParams, # TODO: Update to a push id params + request: TaskIdParams, # TODO: Update to a push id params ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 623843848..8cf2fe8ce 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -3,8 +3,8 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.grpc_handler import GrpcHandler +from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, @@ -14,8 +14,8 @@ __all__ = [ 'DefaultRequestHandler', - 'JSONRPCHandler', 'GrpcHandler', + 'JSONRPCHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 67160c942..2c81da967 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,9 +1,9 @@ import asyncio import logging +import uuid from collections.abc import AsyncGenerator from typing import cast -import uuid from a2a.server.agent_execution import ( AgentExecutor, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 270afbcf3..92125abfd 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -1,37 +1,32 @@ -import logging -import grpc import contextlib +import logging -from typing import AsyncIterable from abc import ABC, abstractmethod +from collections.abc import AsyncIterable + +import grpc + +import a2a.grpc.a2a_pb2_grpc as a2a_grpc +from a2a import types +from a2a.auth.user import UnauthenticatedUser +from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( AgentCard, - InternalError, - Message, - Task, - TaskArtifactUpdateEvent, TaskNotFoundError, - TaskPushNotificationConfig, - TaskStatusUpdateEvent, ) -from a2a import types -from a2a.auth.user import User as A2AUser -from a2a.auth.user import UnauthenticatedUser -from a2a.server.context import ServerCallContext +from a2a.utils import proto_utils from a2a.utils.errors import ServerError from a2a.utils.helpers import validate, validate_async_generator -from a2a.utils import proto_utils -import a2a.grpc.a2a_pb2 as a2a_pb2 -import a2a.grpc.a2a_pb2_grpc as a2a_grpc logger = logging.getLogger(__name__) # For now we use a trivial wrapper on the grpc context object + class CallContextBuilder(ABC): """A class for building ServerCallContexts using the Starlette Request.""" @@ -53,7 +48,8 @@ def build(self, context: grpc.ServicerContext) -> ServerCallContext: class GrpcHandler(a2a_grpc.A2AServiceServicer): """Maps incoming gRPC requests to the appropriate request handler method - and formats responses.""" + and formats responses. + """ def __init__( self, @@ -115,7 +111,7 @@ async def SendStreamingMessage( """Handles the 'StreamMessage' gRPC method. Yields response objects as they are produced by the underlying handler's - stream. + stream. Args: request: The incoming `SendMessageRequest` object. @@ -181,7 +177,7 @@ async def TaskSubscription( """Handles the 'TaskSubscription' gRPC method. Yields response objects as they are produced by the underlying handler's - stream. + stream. Args: request: The incoming `TaskSubscriptionRequest` object. @@ -193,7 +189,8 @@ async def TaskSubscription( try: server_context = self.context_builder.build(context) async for event in self.request_handler.on_resubscribe_to_task( - proto_utils.FromProto.task_id_params(request), server_context, + proto_utils.FromProto.task_id_params(request), + server_context, ): yield proto_utils.ToProto.stream_response(event) except ServerError as e: diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 41591772b..a1cc43eca 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,4 +1,5 @@ """General utility functions for the A2A Python SDK.""" + import functools import logging @@ -147,6 +148,7 @@ def wrapper(self: Any, *args, **kwargs) -> Any: return decorator + def validate_async_generator( expression: Callable[[Any], bool], error_message: str | None = None ): @@ -179,6 +181,7 @@ async def wrapper(self, *args, **kwargs): return decorator + def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None ) -> bool: diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index e118248dc..09c9209ab 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -2,19 +2,20 @@ """Utils for converting between proto and Python types.""" import json -from typing import Any, Dict import re -from a2a.grpc import a2a_pb2 +from typing import Any + +from google.protobuf import json_format, struct_pb2 + from a2a import types +from a2a.grpc import a2a_pb2 from a2a.utils.errors import ServerError -from google.protobuf import struct_pb2 -from google.protobuf import json_format # Regexp patterns for matching -_TASK_NAME_MATCH = r"tasks/(\w+)" -_TASK_PUSH_CONFIG_NAME_MATCH = r"tasks/(\w+)/pushNotifications/(\w+)" +_TASK_NAME_MATCH = r'tasks/(\w+)' +_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)' class ToProto: @@ -34,7 +35,9 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: ) @classmethod - def metadata(cls, metadata: Dict[str, Any] | None) -> struct_pb2.Struct | None: + def metadata( + cls, metadata: dict[str, Any] | None + ) -> struct_pb2.Struct | None: if metadata is None: return None return struct_pb2.Struct( @@ -50,15 +53,14 @@ def metadata(cls, metadata: Dict[str, Any] | None) -> struct_pb2.Struct | None: def part(cls, part: types.Part) -> a2a_pb2.Part: if isinstance(part.root, types.TextPart): return a2a_pb2.Part(text=part.root.text) - elif isinstance(part.root, types.FilePart): + if isinstance(part.root, types.FilePart): return a2a_pb2.Part(file=ToProto.file(part.root.file)) - elif isinstance(part.root, types.DataPart): + if isinstance(part.root, types.DataPart): return a2a_pb2.Part(data=ToProto.data(part.root.data)) - else: - raise ValueError(f"Unsupported part type: {part.root}") + raise ValueError(f'Unsupported part type: {part.root}') @classmethod - def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: + def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart: json_data = json.dumps(data) return a2a_pb2.DataPart( data=json_format.Parse( @@ -68,10 +70,12 @@ def data(cls, data: Dict[str, Any]) -> a2a_pb2.DataPart: ) @classmethod - def file(cls, file: types.FileWithUri | types.FileWithBytes) -> a2a_pb2.FilePart: + def file( + cls, file: types.FileWithUri | types.FileWithBytes + ) -> a2a_pb2.FilePart: if isinstance(file, types.FileWithUri): return a2a_pb2.FilePart(file_with_uri=file.uri) - return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode("utf-8")) + return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) @classmethod def task(cls, task: types.Task) -> a2a_pb2.Task: @@ -85,7 +89,9 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: else None ), history=( - [ToProto.message(h) for h in task.history] if task.history else None # type: ignore[misc] + [ToProto.message(h) for h in task.history] + if task.history + else None # type: ignore[misc] ), ) @@ -197,16 +203,15 @@ def update_event( return a2a_pb2.StreamResponse( status_update=ToProto.task_status_update_event(event) ) - elif isinstance(event, types.TaskArtifactUpdateEvent): + if isinstance(event, types.TaskArtifactUpdateEvent): return a2a_pb2.StreamResponse( artifact_update=ToProto.task_artifact_update_event(event) ) - elif isinstance(event, types.Message): + if isinstance(event, types.Message): return a2a_pb2.StreamResponse(msg=ToProto.message(event)) - elif isinstance(event, types.Task): + if isinstance(event, types.Task): return a2a_pb2.StreamResponse(task=ToProto.task(event)) - else: - raise ValueError(f"Unsupported event type: {type(event)}") + raise ValueError(f'Unsupported event type: {type(event)}') @classmethod def task_or_message( @@ -232,9 +237,9 @@ def stream_response( ) -> a2a_pb2.StreamResponse: if isinstance(event, types.Message): return a2a_pb2.StreamResponse(msg=cls.message(event)) - elif isinstance(event, types.Task): + if isinstance(event, types.Task): return a2a_pb2.StreamResponse(task=cls.task(event)) - elif isinstance(event, types.TaskStatusUpdateEvent): + if isinstance(event, types.TaskStatusUpdateEvent): return a2a_pb2.StreamResponse( status_update=cls.task_status_update_event(event), ) @@ -247,7 +252,7 @@ def task_push_notification_config( cls, config: types.TaskPushNotificationConfig ) -> a2a_pb2.TaskPushNotificationConfig: return a2a_pb2.TaskPushNotificationConfig( - name=f"tasks/{config.taskId}/pushNotifications/{config.taskId}", + name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}', push_notification_config=cls.push_notification_config( config.pushNotificationConfig, ), @@ -305,7 +310,9 @@ def security( for s in security: rval.append( a2a_pb2.Security( - schemes={k: a2a_pb2.StringList(list=v) for (k, v) in s.items()} + schemes={ + k: a2a_pb2.StringList(list=v) for (k, v) in s.items() + } ) ) return rval @@ -362,18 +369,20 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: authorization_url=flows.authorizationCode.authorizationUrl, refresh_url=flows.authorizationCode.refreshUrl, scopes={ - k: v for (k, v) in flows.authorizationCode.scopes.items() + k: v + for (k, v) in flows.authorizationCode.scopes.items() }, - token_url=flows.authorizationCode.tokenUrl, - ), + token_url=flows.authorizationCode.tokenUrl, + ), ) if flows.clientCredentials: return a2a_pb2.OAuthFlows( client_credentials=a2a_pb2.ClientCredentialsOAuthFlow( refresh_url=flows.clientCredentials.refreshUrl, scopes={ - k: v for (k, v) in flows.clientCredentials.scopes.items() - }, + k: v + for (k, v) in flows.clientCredentials.scopes.items() + }, token_url=flows.clientCredentials.tokenUrl, ), ) @@ -393,7 +402,7 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: token_url=flows.password.tokenUrl, ), ) - raise ValueError("Unknown oauth flow definition") + raise ValueError('Unknown oauth flow definition') @classmethod def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: @@ -433,7 +442,7 @@ def message(cls, message: a2a_pb2.Message) -> types.Message: ) @classmethod - def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: + def metadata(cls, metadata: struct_pb2.Struct) -> dict[str, Any]: return { key: value.string_value for key, value in metadata.fields.items() @@ -442,25 +451,30 @@ def metadata(cls, metadata: struct_pb2.Struct) -> Dict[str, Any]: @classmethod def part(cls, part: a2a_pb2.Part) -> types.Part: - if part.HasField("text"): + if part.HasField('text'): return types.Part(root=types.TextPart(text=part.text)) - elif part.HasField("file"): - return types.Part(root=types.FilePart(file=FromProto.file(part.file))) - elif part.HasField("data"): - return types.Part(root=types.DataPart(data=FromProto.data(part.data))) - else: - raise ValueError(f"Unsupported part type: {part}") + if part.HasField('file'): + return types.Part( + root=types.FilePart(file=FromProto.file(part.file)) + ) + if part.HasField('data'): + return types.Part( + root=types.DataPart(data=FromProto.data(part.data)) + ) + raise ValueError(f'Unsupported part type: {part}') @classmethod - def data(cls, data: a2a_pb2.DataPart) -> Dict[str, Any]: + def data(cls, data: a2a_pb2.DataPart) -> dict[str, Any]: json_data = json_format.MessageToJson(data.data) return json.loads(json_data) @classmethod - def file(cls, file: a2a_pb2.FilePart) -> types.FileWithUri | types.FileWithBytes: - if file.HasField("file_with_uri"): + def file( + cls, file: a2a_pb2.FilePart + ) -> types.FileWithUri | types.FileWithBytes: + if file.HasField('file_with_uri'): return types.FileWithUri(uri=file.file_with_uri) - return types.FileWithBytes(bytes=file.file_with_bytes.decode("utf-8")) + return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) @classmethod def task(cls, task: a2a_pb2.Task) -> types.Task: @@ -591,7 +605,7 @@ def task_id_params( if not m: raise ServerError( error=types.InvalidParamsError( - message=f"No task for {request.name}" + message=f'No task for {request.name}' ) ) return types.TaskIdParams(id=m.group(1)) @@ -599,7 +613,7 @@ def task_id_params( if not m: raise ServerError( error=types.InvalidParamsError( - message=f"No task for {request.name}" + message=f'No task for {request.name}' ) ) return types.TaskIdParams(id=m.group(1)) @@ -613,7 +627,7 @@ def task_push_notification_config( if not m: raise ServerError( error=types.InvalidParamsError( - message=f"No task for {request.parent}" + message=f'No task for {request.parent}' ) ) return types.TaskPushNotificationConfig( @@ -653,7 +667,7 @@ def task_query_params( if not m: raise ServerError( error=types.InvalidParamsError( - message=f"No task for {request.name}" + message=f'No task for {request.name}' ) ) return types.TaskQueryParams( @@ -698,8 +712,7 @@ def provider( @classmethod def security_schemes( - cls, - schemes: dict[str, a2a_pb2.SecurityScheme] + cls, schemes: dict[str, a2a_pb2.SecurityScheme] ) -> dict[str, types.SecurityScheme]: return {k: cls.security_scheme(v) for (k, v) in schemes.items()} @@ -708,16 +721,16 @@ def security_scheme( cls, scheme: a2a_pb2.SecurityScheme, ) -> types.SecurityScheme: - if scheme.HasField("api_key_security_scheme"): + if scheme.HasField('api_key_security_scheme'): ss = types.SecurityScheme( root=types.APIKeySecurityScheme( description=scheme.api_key_security_scheme.description, name=scheme.api_key_security_scheme.name, - in_= scheme.api_key_security_scheme.location, # type: ignore[call-arg] + in_=scheme.api_key_security_scheme.location, # type: ignore[call-arg] ) ) return ss - if scheme.HasField("http_auth_security_scheme"): + if scheme.HasField('http_auth_security_scheme'): return types.SecurityScheme( root=types.HTTPAuthSecurityScheme( description=scheme.http_auth_security_scheme.description, @@ -725,7 +738,7 @@ def security_scheme( bearerFormat=scheme.http_auth_security_scheme.bearer_format, ) ) - if scheme.HasField("oauth2_security_scheme"): + if scheme.HasField('oauth2_security_scheme'): return types.SecurityScheme( root=types.OAuth2SecurityScheme( description=scheme.oauth2_security_scheme.description, @@ -741,28 +754,30 @@ def security_scheme( @classmethod def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: - if flows.HasField("authorization_code"): + if flows.HasField('authorization_code'): return types.OAuthFlows( authorizationCode=types.AuthorizationCodeOAuthFlow( authorizationUrl=flows.authorization_code.authorization_url, refreshUrl=flows.authorization_code.refresh_url, scopes={ - k: v for (k, v) in flows.authorization_code.scopes.items() + k: v + for (k, v) in flows.authorization_code.scopes.items() }, tokenUrl=flows.authorization_code.token_url, ), ) - if flows.HasField("client_credentials"): + if flows.HasField('client_credentials'): return types.OAuthFlows( clientCredentials=types.ClientCredentialsOAuthFlow( refreshUrl=flows.client_credentials.refresh_url, scopes={ - k: v for (k, v) in flows.client_credentials.scopes.items() + k: v + for (k, v) in flows.client_credentials.scopes.items() }, tokenUrl=flows.client_credentials.token_url, ), ) - if flows.HasField("implicit"): + if flows.HasField('implicit'): return types.OAuthFlows( implicit=types.ImplicitOAuthFlow( authorizationUrl=flows.implicit.authorization_url, From 7bd25ab8bd3d39cab3626e03af186facbc23215e Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 6 Jun 2025 11:22:43 -0500 Subject: [PATCH 24/29] Ignore Docstring error in `proto_utils.py` --- .github/linters/.ruff.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 5e84c97a5..28d0f614c 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -138,6 +138,7 @@ inline-quotes = "single" "SLF001", ] "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py +"proto_utils.py" = ["D102"] [format] exclude = [ From c83f32985a2b988b6533495206ed42fd50828e29 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 6 Jun 2025 11:43:11 -0500 Subject: [PATCH 25/29] Lint fixes (ruff `unsafe-fixes`) --- .../server/request_handlers/grpc_handler.py | 2 +- src/a2a/utils/proto_utils.py | 31 ++++++------------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 92125abfd..125d201c0 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -299,7 +299,7 @@ async def GetAgentCard( async def abort_context( self, error: ServerError, context: grpc.ServicerContext - ): + ) -> None: match error.error: case types.JSONParseError(): await context.abort( diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 09c9209ab..d519f1a65 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -368,10 +368,7 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: authorization_code=a2a_pb2.AuthorizationCodeOAuthFlow( authorization_url=flows.authorizationCode.authorizationUrl, refresh_url=flows.authorizationCode.refreshUrl, - scopes={ - k: v - for (k, v) in flows.authorizationCode.scopes.items() - }, + scopes=dict(flows.authorizationCode.scopes.items()), token_url=flows.authorizationCode.tokenUrl, ), ) @@ -379,10 +376,7 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: return a2a_pb2.OAuthFlows( client_credentials=a2a_pb2.ClientCredentialsOAuthFlow( refresh_url=flows.clientCredentials.refreshUrl, - scopes={ - k: v - for (k, v) in flows.clientCredentials.scopes.items() - }, + scopes=dict(flows.clientCredentials.scopes.items()), token_url=flows.clientCredentials.tokenUrl, ), ) @@ -391,14 +385,14 @@ def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: implicit=a2a_pb2.ImplicitOAuthFlow( authorization_url=flows.implicit.authorizationUrl, refresh_url=flows.implicit.refreshUrl, - scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + scopes=dict(flows.implicit.scopes.items()), ), ) if flows.password: return a2a_pb2.OAuthFlows( password=a2a_pb2.PasswordOAuthFlow( refresh_url=flows.password.refreshUrl, - scopes={k: v for (k, v) in flows.password.scopes.items()}, + scopes=dict(flows.password.scopes.items()), token_url=flows.password.tokenUrl, ), ) @@ -722,14 +716,13 @@ def security_scheme( scheme: a2a_pb2.SecurityScheme, ) -> types.SecurityScheme: if scheme.HasField('api_key_security_scheme'): - ss = types.SecurityScheme( + return types.SecurityScheme( root=types.APIKeySecurityScheme( description=scheme.api_key_security_scheme.description, name=scheme.api_key_security_scheme.name, in_=scheme.api_key_security_scheme.location, # type: ignore[call-arg] ) ) - return ss if scheme.HasField('http_auth_security_scheme'): return types.SecurityScheme( root=types.HTTPAuthSecurityScheme( @@ -759,10 +752,7 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: authorizationCode=types.AuthorizationCodeOAuthFlow( authorizationUrl=flows.authorization_code.authorization_url, refreshUrl=flows.authorization_code.refresh_url, - scopes={ - k: v - for (k, v) in flows.authorization_code.scopes.items() - }, + scopes=dict(flows.authorization_code.scopes.items()), tokenUrl=flows.authorization_code.token_url, ), ) @@ -770,10 +760,7 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: return types.OAuthFlows( clientCredentials=types.ClientCredentialsOAuthFlow( refreshUrl=flows.client_credentials.refresh_url, - scopes={ - k: v - for (k, v) in flows.client_credentials.scopes.items() - }, + scopes=dict(flows.client_credentials.scopes.items()), tokenUrl=flows.client_credentials.token_url, ), ) @@ -782,13 +769,13 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: implicit=types.ImplicitOAuthFlow( authorizationUrl=flows.implicit.authorization_url, refreshUrl=flows.implicit.refresh_url, - scopes={k: v for (k, v) in flows.implicit.scopes.items()}, + scopes=dict(flows.implicit.scopes.items()), ), ) return types.OAuthFlows( password=types.PasswordOAuthFlow( refreshUrl=flows.password.refresh_url, - scopes={k: v for (k, v) in flows.password.scopes.items()}, + scopes=dict(flows.password.scopes.items()), tokenUrl=flows.password.token_url, ), ) From feff2d744f3dcd449006b75c99c7516aec9c57d1 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 17:42:50 +0000 Subject: [PATCH 26/29] Fix ruff errors --- src/a2a/server/request_handlers/grpc_handler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 125d201c0..2c1804c05 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -1,3 +1,4 @@ +# ruff: noqa: N802 import contextlib import logging @@ -55,7 +56,7 @@ def __init__( self, agent_card: AgentCard, request_handler: RequestHandler, - context_builder: CallContextBuilder = DefaultCallContextBuilder(), + context_builder: CallContextBuilder | None = None ): """Initializes the GrpcHandler. @@ -66,7 +67,7 @@ def __init__( """ self.agent_card = agent_card self.request_handler = request_handler - self.context_builder = context_builder + self.context_builder = context_builder or DefaultCallContextBuilder() async def SendMessage( self, @@ -295,11 +296,13 @@ async def GetAgentCard( request: a2a_pb2.GetAgentCardRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.AgentCard: + """Get the agent card for the agent served.""" return proto_utils.ToProto.agent_card(self.agent_card) async def abort_context( self, error: ServerError, context: grpc.ServicerContext ) -> None: + """Sets the grpc errors appropriately in the context.""" match error.error: case types.JSONParseError(): await context.abort( From a44ab73a51e6d19ec579a09dd49cf6f64738526f Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 17:50:43 +0000 Subject: [PATCH 27/29] Additional ruff/mypy fixes --- src/a2a/server/request_handlers/grpc_handler.py | 7 ++++--- src/a2a/utils/proto_utils.py | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 2c1804c05..b8c21070a 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -40,6 +40,7 @@ class DefaultCallContextBuilder(CallContextBuilder): """A default implementation of CallContextBuilder.""" def build(self, context: grpc.ServicerContext) -> ServerCallContext: + """Builds the ServerCallContext.""" user = UnauthenticatedUser() state = {} with contextlib.suppress(Exception): @@ -48,9 +49,7 @@ def build(self, context: grpc.ServicerContext) -> ServerCallContext: class GrpcHandler(a2a_grpc.A2AServiceServicer): - """Maps incoming gRPC requests to the appropriate request handler method - and formats responses. - """ + """Maps incoming gRPC requests to the appropriate request handler method.""" def __init__( self, @@ -64,6 +63,8 @@ def __init__( agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. + context_builder: The CallContextBuilder object. If none the + DefaultCallContextBuilder is used. """ self.agent_card = agent_card self.request_handler = request_handler diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index d519f1a65..14bb0fb19 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -89,9 +89,9 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: else None ), history=( - [ToProto.message(h) for h in task.history] + [ToProto.message(h) for h in task.history] # type: ignore[misc] if task.history - else None # type: ignore[misc] + else None ), ) @@ -103,7 +103,7 @@ def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: ) @classmethod - def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: # ruff: noqa: PLR0911 match state: case types.TaskState.submitted: return a2a_pb2.TaskState.TASK_STATE_SUBMITTED @@ -488,7 +488,7 @@ def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: ) @classmethod - def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: # ruff: noqa: PLR0911 match state: case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: return types.TaskState.submitted From 434b7fa27ce2a29763b6020008967d7c2ae071c3 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 17:55:58 +0000 Subject: [PATCH 28/29] Update .ruff.toml to exclude a few rules for a few files --- .github/linters/.ruff.toml | 3 ++- src/a2a/utils/proto_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 28d0f614c..98ec7d43a 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -138,7 +138,8 @@ inline-quotes = "single" "SLF001", ] "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py -"proto_utils.py" = ["D102"] +"proto_utils.py" = ["D102", "PLR0911"] +"helpers.py" = ["ANN001"] [format] exclude = [ diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 14bb0fb19..e1dddc393 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -103,7 +103,7 @@ def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: ) @classmethod - def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: # ruff: noqa: PLR0911 + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: match state: case types.TaskState.submitted: return a2a_pb2.TaskState.TASK_STATE_SUBMITTED @@ -488,7 +488,7 @@ def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: ) @classmethod - def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: # ruff: noqa: PLR0911 + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: match state: case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: return types.TaskState.submitted From 2e0c17602281647d5986eec78f886e5f0de90ec9 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 6 Jun 2025 18:02:10 +0000 Subject: [PATCH 29/29] More ruff rules --- .github/linters/.ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 98ec7d43a..34dbfa2b9 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -139,7 +139,7 @@ inline-quotes = "single" ] "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py "proto_utils.py" = ["D102", "PLR0911"] -"helpers.py" = ["ANN001"] +"helpers.py" = ["ANN001", "ANN201", "ANN202"] [format] exclude = [