diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 25fe25258..afe777120 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -286,11 +286,19 @@ async def on_message_send( interrupted_or_non_blocking = False try: + # Create async callback for push notifications + async def push_notification_callback() -> None: + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) + ( result, interrupted_or_non_blocking, ) = await result_aggregator.consume_and_break_on_interrupt( - consumer, blocking=blocking + consumer, + blocking=blocking, + event_callback=push_notification_callback, ) if not result: raise ServerError(error=InternalError()) # noqa: TRY301 diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index 147c32022..fb1ab62ef 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -1,7 +1,7 @@ import asyncio import logging -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable from a2a.server.events import Event, EventConsumer from a2a.server.tasks.task_manager import TaskManager @@ -24,7 +24,10 @@ class ResultAggregator: Task object and emit that Task object. """ - def __init__(self, task_manager: TaskManager): + def __init__( + self, + task_manager: TaskManager, + ) -> None: """Initializes the ResultAggregator. Args: @@ -92,7 +95,10 @@ async def consume_all( return await self.task_manager.get_task() async def consume_and_break_on_interrupt( - self, consumer: EventConsumer, blocking: bool = True + self, + consumer: EventConsumer, + blocking: bool = True, + event_callback: Callable[[], Awaitable[None]] | None = None, ) -> tuple[Task | Message | None, bool]: """Processes the event stream until completion or an interruptable state is encountered. @@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt( consumer: The `EventConsumer` to read events from. blocking: If `False`, the method returns as soon as a task/message is available. If `True`, it waits for a terminal state. + event_callback: Optional async callback function to be called after each event + is processed in the background continuation. + Mainly used for push notifications currently. Returns: A tuple containing: @@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt( if should_interrupt: # Continue consuming the rest of the events in the background. # TODO: We should track all outstanding tasks to ensure they eventually complete. - asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006 + asyncio.create_task( # noqa: RUF006 + self._continue_consuming(event_stream, event_callback) + ) interrupted = True break return await self.task_manager.get_task(), interrupted async def _continue_consuming( - self, event_stream: AsyncIterator[Event] + self, + event_stream: AsyncIterator[Event], + event_callback: Callable[[], Awaitable[None]] | None = None, ) -> None: """Continues processing an event stream in a background task. @@ -165,6 +178,9 @@ async def _continue_consuming( Args: event_stream: The remaining `AsyncIterator` of events from the consumer. + event_callback: Optional async callback function to be called after each event is processed. """ async for event in event_stream: await self.task_manager.process(event) + if event_callback: + await event_callback() diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6cb21662c..e8906554a 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -405,6 +405,134 @@ async def get_current_result(): mock_agent_executor.execute.assert_awaited_once() +@pytest.mark.asyncio +async def test_on_message_send_with_push_notification_in_non_blocking_request(): + """Test that push notification callback is called during background event processing for non-blocking requests.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + mock_push_sender = AsyncMock() + + task_id = 'non_blocking_task_1' + context_id = 'non_blocking_ctx_1' + + # Create a task that will be returned after the first event + initial_task = create_sample_task( + task_id=task_id, context_id=context_id, status_state=TaskState.working + ) + + # Create a final task that will be available during background processing + final_task = create_sample_task( + task_id=task_id, context_id=context_id, status_state=TaskState.completed + ) + + mock_task_store.get.return_value = None + + # Mock request context + mock_request_context = MagicMock(spec=RequestContext) + mock_request_context.task_id = task_id + mock_request_context.context_id = context_id + mock_request_context_builder.build.return_value = mock_request_context + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + push_config_store=mock_push_notification_store, + request_context_builder=mock_request_context_builder, + push_sender=mock_push_sender, + ) + + # Configure push notification + push_config = PushNotificationConfig(url='http://callback.com/push') + message_config = MessageSendConfiguration( + push_notification_config=push_config, + accepted_output_modes=['text/plain'], + blocking=False, # Non-blocking request + ) + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_non_blocking', + parts=[], + task_id=task_id, + context_id=context_id, + ), + configuration=message_config, + ) + + # Mock ResultAggregator with custom behavior + mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) + + # First call returns the initial task and indicates interruption (non-blocking) + mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( + initial_task, + True, # interrupted = True for non-blocking + ) + + # Mock the current_result property to return the final task + async def get_current_result(): + return final_task + + type(mock_result_aggregator_instance).current_result = PropertyMock( + return_value=get_current_result() + ) + + # Track if the event_callback was passed to consume_and_break_on_interrupt + event_callback_passed = False + event_callback_received = None + + async def mock_consume_and_break_on_interrupt( + consumer, blocking=True, event_callback=None + ): + nonlocal event_callback_passed, event_callback_received + event_callback_passed = event_callback is not None + event_callback_received = event_callback + return initial_task, True # interrupted = True for non-blocking + + mock_result_aggregator_instance.consume_and_break_on_interrupt = ( + mock_consume_and_break_on_interrupt + ) + + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=initial_task, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message', + return_value=initial_task, + ), + ): + # Execute the non-blocking request + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + + # Verify the result is the initial task (non-blocking behavior) + assert result == initial_task + + # Verify that the event_callback was passed to consume_and_break_on_interrupt + assert event_callback_passed, ( + 'event_callback should have been passed to consume_and_break_on_interrupt' + ) + assert event_callback_received is not None, ( + 'event_callback should not be None' + ) + + # Verify that the push notification was sent with the final task + mock_push_sender.send_notification.assert_called_with(final_task) + + # Verify that the push notification config was stored + mock_push_notification_store.set_info.assert_awaited_once_with( + task_id, push_config + ) + + @pytest.mark.asyncio async def test_on_message_send_with_push_notification_no_existing_Task(): """Test on_message_send for new task sets push notification info if provided."""