Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any, List, Sequence

from autogen_core import DefaultTopicId, MessageContext, event, rpc
from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc

from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
self._current_turn = 0
self._message_factory = message_factory
self._emit_team_events = emit_team_events
self._active_speakers: List[str] = []

@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
Expand Down Expand Up @@ -122,64 +123,64 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
# Stop the group chat.
return

# Select a speaker to start/continue the conversation
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
await self._log_speaker_selection(speaker_name)

# Send the message to the next speaker
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)

async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
self._message_thread.extend(messages)
# Select speakers to start/continue the conversation
await self._transition_to_next_speakers(ctx.cancellation_token)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
try:
# Append the message to the message thread and construct the delta.
# Construct the detla from the agent response.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
delta.append(message.agent_response.chat_message)

# Append the messages to the message thread.
await self.update_message_thread(delta)

# Remove the agent from the active speakers list.
self._active_speakers.remove(message.agent_name)
if len(self._active_speakers) > 0:
# If there are still active speakers, return without doing anything.
return

# Check if the conversation should be terminated.
if await self._apply_termination_condition(delta, increment_turn_count=True):
# Stop the group chat.
return

# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
# Select speakers to continue the conversation.
await self._transition_to_next_speakers(ctx.cancellation_token)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)
await self._signal_termination_with_error(error)
# Raise the exception to the runtime.
raise

async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
cancellation_token.link_future(speaker_names_future)
speaker_names = await speaker_names_future
if isinstance(speaker_names, str):
# If only one speaker is selected, convert it to a list.
speaker_names = [speaker_names]
for speaker_name in speaker_names:
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
await self._log_speaker_selection(speaker_name)
await self._log_speaker_selection(speaker_names)

# Send the message to the next speakers
# Send request to publish message to the next speakers
for speaker_name in speaker_names:
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
cancellation_token=cancellation_token,
)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)
await self._signal_termination_with_error(error)
# Raise the exception to the runtime.
raise
self._active_speakers.append(speaker_name)

async def _apply_termination_condition(
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
Expand Down Expand Up @@ -216,9 +217,9 @@ async def _apply_termination_condition(
return True
return False

async def _log_speaker_selection(self, speaker_name: str) -> None:
async def _log_speaker_selection(self, speaker_names: List[str]) -> None:
"""Log the selected speaker to the output message queue."""
select_msg = SelectSpeakerEvent(content=[speaker_name], source=self._name)
select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name)
if self._emit_team_events:
await self.publish_message(
GroupChatMessage(message=select_msg),
Expand Down Expand Up @@ -284,10 +285,26 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) ->
"""
...

async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
"""Update the message thread with the new messages.
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
before calling the select_speakers method.
"""
self._message_thread.extend(messages)

@abstractmethod
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Select speakers from the participants and return the topic types of the selected speaker.
This is called when the group chat manager have received all responses from the participants
for a turn and is ready to select the next speakers for the next turn.

Args:
thread: The message thread of the group chat.

Returns:
A list of topic types of the selected speakers.
If only one speaker is selected, a single string is returned instead of a list.
"""
...

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
GroupChatAgentResponse(agent_response=response, agent_name=self._agent.name),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class GroupChatAgentResponse(BaseModel):
agent_response: Response
"""The response from an agent."""

agent_name: str
"""The name of the agent that produced the response."""


class GroupChatRequestPublish(BaseModel):
"""A request to publish a message to a group chat."""
Expand Down
Loading
Loading