first commit
This commit is contained in:
30
agent_dhal/agentdhal_agentchat/teams/__init__.py
Normal file
30
agent_dhal/agentdhal_agentchat/teams/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
This module provides implementation of various pre-defined multi-agent teams.
|
||||
Each team inherits from the BaseGroupChat class.
|
||||
"""
|
||||
|
||||
from ._group_chat._base_group_chat import BaseGroupChat
|
||||
from ._group_chat._graph import (
|
||||
DiGraph,
|
||||
DiGraphBuilder,
|
||||
DiGraphEdge,
|
||||
DiGraphNode,
|
||||
GraphFlow,
|
||||
)
|
||||
from ._group_chat._magentic_one import MagenticOneGroupChat
|
||||
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
|
||||
from ._group_chat._selector_group_chat import SelectorGroupChat
|
||||
from ._group_chat._swarm_group_chat import Swarm
|
||||
|
||||
__all__ = [
|
||||
"BaseGroupChat",
|
||||
"RoundRobinGroupChat",
|
||||
"SelectorGroupChat",
|
||||
"Swarm",
|
||||
"MagenticOneGroupChat",
|
||||
"DiGraphBuilder",
|
||||
"DiGraph",
|
||||
"DiGraphNode",
|
||||
"DiGraphEdge",
|
||||
"GraphFlow",
|
||||
]
|
||||
@@ -0,0 +1,834 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import (
|
||||
AgentId,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
ComponentBase,
|
||||
SingleThreadedAgentRuntime,
|
||||
TypeSubscription,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import (
|
||||
GroupChatPause,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
"""The base class for group chat teams.
|
||||
|
||||
In a group chat team, participants share context by publishing their messages
|
||||
to all other participants.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
|
||||
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
|
||||
|
||||
This base class provides the mapping between the agents of the AgentChat API
|
||||
and the agent runtime of the Core API, and handles high-level features like
|
||||
running, pausing, resuming, and resetting the team.
|
||||
"""
|
||||
|
||||
component_type = "team"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
participants: List[ChatAgent | Team],
|
||||
group_chat_manager_name: str,
|
||||
group_chat_manager_class: type[SequentialRoutedAgent],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self._description = description
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
if len(participants) != len(set(participant.name for participant in participants)):
|
||||
raise ValueError("The participant names must be unique.")
|
||||
self._participants = participants
|
||||
self._base_group_chat_manager_class = group_chat_manager_class
|
||||
self._termination_condition = termination_condition
|
||||
self._max_turns = max_turns
|
||||
self._message_factory = MessageFactory()
|
||||
if custom_message_types is not None:
|
||||
for message_type in custom_message_types:
|
||||
self._message_factory.register(message_type)
|
||||
|
||||
for agent in participants:
|
||||
if isinstance(agent, ChatAgent):
|
||||
for message_type in agent.produced_message_types:
|
||||
try:
|
||||
is_registered = self._message_factory.is_registered(message_type) # type: ignore[reportUnknownArgumentType]
|
||||
if issubclass(message_type, StructuredMessage) and not is_registered:
|
||||
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
|
||||
except TypeError:
|
||||
# Not a class or not a valid subclassable type (skip)
|
||||
pass
|
||||
|
||||
# The team ID is a UUID that is used to identify the team and its participants
|
||||
# in the agent runtime. It is used to create unique topic types for each participant.
|
||||
# Currently, team ID is binded to an object instance of the group chat class.
|
||||
# So if you create two instances of group chat, there will be two teams with different IDs.
|
||||
self._team_id = str(uuid.uuid4())
|
||||
|
||||
# Constants for the group chat team.
|
||||
# The names are used to identify the agents within the team.
|
||||
# The names may not be unique across different teams.
|
||||
self._group_chat_manager_name = group_chat_manager_name
|
||||
self._participant_names: List[str] = [participant.name for participant in participants]
|
||||
self._participant_descriptions: List[str] = [participant.description for participant in participants]
|
||||
# The group chat topic type is used for broadcast communication among all participants and the group chat manager.
|
||||
self._group_topic_type = f"group_topic_{self._team_id}"
|
||||
# The group chat manager topic type is used for direct communication with the group chat manager.
|
||||
self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}"
|
||||
# The participant topic types are used for direct communication with each participant.
|
||||
self._participant_topic_types: List[str] = [
|
||||
f"{participant.name}_{self._team_id}" for participant in participants
|
||||
]
|
||||
# The output topic type is used for emitting streaming messages from the group chat.
|
||||
# The group chat manager will relay the messages to the output message queue.
|
||||
self._output_topic_type = f"output_topic_{self._team_id}"
|
||||
|
||||
# The queue for collecting the output messages.
|
||||
self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
|
||||
# Create a runtime for the team.
|
||||
if runtime is not None:
|
||||
self._runtime = runtime
|
||||
self._embedded_runtime = False
|
||||
else:
|
||||
# Use a embedded single-threaded runtime for the group chat.
|
||||
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
|
||||
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
self._embedded_runtime = True
|
||||
|
||||
# Flag to track if the group chat has been initialized.
|
||||
self._initialized = False
|
||||
|
||||
# Flag to track if the group chat is running.
|
||||
self._is_running = False
|
||||
|
||||
# Flag to track if the team events should be emitted.
|
||||
self._emit_team_events = emit_team_events
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the group chat team."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""A description of the group chat team."""
|
||||
return self._description
|
||||
|
||||
@abstractmethod
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], SequentialRoutedAgent]: ...
|
||||
|
||||
def _create_participant_factory(
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
output_topic_type: str,
|
||||
agent: ChatAgent | Team,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], ChatAgentContainer]:
|
||||
def _factory() -> ChatAgentContainer:
|
||||
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent, message_factory)
|
||||
return container
|
||||
|
||||
return _factory
|
||||
|
||||
async def _init(self, runtime: AgentRuntime) -> None:
|
||||
# Constants for the group chat manager.
|
||||
group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type)
|
||||
|
||||
# Register participants.
|
||||
# Use the participant topic type as the agent type.
|
||||
for participant, agent_type in zip(self._participants, self._participant_topic_types, strict=True):
|
||||
# Register the participant factory.
|
||||
await ChatAgentContainer.register(
|
||||
runtime,
|
||||
type=agent_type,
|
||||
factory=self._create_participant_factory(
|
||||
self._group_topic_type, self._output_topic_type, participant, self._message_factory
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the participant.
|
||||
# The participant should be able to receive messages from its own topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=agent_type, agent_type=agent_type))
|
||||
# The participant should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type))
|
||||
|
||||
# Register the group chat manager.
|
||||
await self._base_group_chat_manager_class.register(
|
||||
runtime,
|
||||
type=group_chat_manager_agent_type.type,
|
||||
factory=self._create_group_chat_manager_factory(
|
||||
name=self._group_chat_manager_name,
|
||||
group_topic_type=self._group_topic_type,
|
||||
output_topic_type=self._output_topic_type,
|
||||
participant_names=self._participant_names,
|
||||
participant_topic_types=self._participant_topic_types,
|
||||
participant_descriptions=self._participant_descriptions,
|
||||
output_message_queue=self._output_message_queue,
|
||||
termination_condition=self._termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
message_factory=self._message_factory,
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the group chat manager.
|
||||
# The group chat manager should be able to receive messages from the its own topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(
|
||||
topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type
|
||||
)
|
||||
)
|
||||
# The group chat manager should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
# The group chat manager will relay the messages from output topic to the output message queue.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._output_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
:meth:`run_stream` to run the team and then returns the final result.
|
||||
Once the team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead.
|
||||
|
||||
Returns:
|
||||
result: The result of the task as :class:`~agentdhal_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
result = await team.run(task="Count from 1 to 10, respond one at a time.")
|
||||
print(result)
|
||||
|
||||
# Run the team again without a task to continue the previous task.
|
||||
result = await team.run()
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create a task to run the team in the background.
|
||||
run_task = asyncio.create_task(
|
||||
team.run(
|
||||
task="Count from 1 to 10, respond one at a time.",
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for 1 second and then cancel the task.
|
||||
await asyncio.sleep(1)
|
||||
cancellation_token.cancel()
|
||||
|
||||
# This will raise a cancellation error.
|
||||
await run_task
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
result: TaskResult | None = None
|
||||
async for message in self.run_stream(
|
||||
task=task,
|
||||
cancellation_token=cancellation_token,
|
||||
output_task_messages=output_task_messages,
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
result = message
|
||||
if result is not None:
|
||||
return result
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
of the type :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream. Once the
|
||||
team is stopped, the termination condition is reset.
|
||||
|
||||
.. note::
|
||||
|
||||
If an agent produces :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent`,
|
||||
the message will be yielded in the stream but it will not be included in the
|
||||
:attr:`~agentdhal_agentchat.base.TaskResult.messages`.
|
||||
|
||||
Args:
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead.
|
||||
output_task_messages (bool): Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
|
||||
Returns:
|
||||
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~agentdhal_agentchat.messages.BaseAgentEvent`, :class:`~agentdhal_agentchat.messages.BaseChatMessage`, and the final result :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
# Run the team again without a task to continue the previous task.
|
||||
stream = team.run_stream()
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create a task to run the team in the background.
|
||||
run_task = asyncio.create_task(
|
||||
Console(
|
||||
team.run_stream(
|
||||
task="Count from 1 to 10, respond one at a time.",
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for 1 second and then cancel the task.
|
||||
await asyncio.sleep(1)
|
||||
cancellation_token.cancel()
|
||||
|
||||
# This will raise a cancellation error.
|
||||
await run_task
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
# Create the messages list if the task is a string or a chat message.
|
||||
messages: List[BaseChatMessage] | None = None
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
messages = [TextMessage(content=task, source="user")]
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
messages = [task]
|
||||
elif isinstance(task, list):
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
messages = []
|
||||
for msg in task:
|
||||
if not isinstance(msg, BaseChatMessage):
|
||||
raise ValueError("All messages in task list must be valid BaseChatMessage types")
|
||||
messages.append(msg)
|
||||
else:
|
||||
raise ValueError("Task must be a string, a BaseChatMessage, or a list of BaseChatMessage.")
|
||||
# Check if the messages types are registered with the message factory.
|
||||
if messages is not None:
|
||||
for msg in messages:
|
||||
if not self._message_factory.is_registered(msg.__class__):
|
||||
raise ValueError(
|
||||
f"Message type {msg.__class__} is not registered with the message factory. "
|
||||
"Please register it with the message factory by adding it to the "
|
||||
"custom_message_types list when creating the team."
|
||||
)
|
||||
|
||||
if self._is_running:
|
||||
raise ValueError("The team is already running, it cannot run again until it is stopped.")
|
||||
self._is_running = True
|
||||
|
||||
if self._embedded_runtime:
|
||||
# Start the embedded runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
shutdown_task: asyncio.Task[None] | None = None
|
||||
if self._embedded_runtime:
|
||||
|
||||
async def stop_runtime() -> None:
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
try:
|
||||
# This will propagate any exceptions raised.
|
||||
await self._runtime.stop_when_idle()
|
||||
# Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
|
||||
# but not due to an exception.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(
|
||||
content="The group chat is stopped.", source=self._group_chat_manager_name
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Stop the consumption of messages and end the stream.
|
||||
# NOTE: we also need to put a GroupChatTermination event here because when the runtime
|
||||
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
|
||||
# This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(
|
||||
content="An exception occurred in the runtime.", source=self._group_chat_manager_name
|
||||
),
|
||||
error=SerializableException.from_exception(e),
|
||||
)
|
||||
)
|
||||
|
||||
# Create a background task to stop the runtime when the group chat
|
||||
# is stopped or has an exception.
|
||||
shutdown_task = asyncio.create_task(stop_runtime())
|
||||
|
||||
try:
|
||||
# Run the team by sending the start message to the group chat manager.
|
||||
# The group chat manager will start the group chat by relaying the message to the participants
|
||||
# and the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatStart(messages=messages, output_task_messages=output_task_messages),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Collect the output messages in order.
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
stop_reason: str | None = None
|
||||
|
||||
# Yield the messages until the queue is empty.
|
||||
while True:
|
||||
message_future = asyncio.ensure_future(self._output_message_queue.get())
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(message_future)
|
||||
# Wait for the next message, this will raise an exception if the task is cancelled.
|
||||
message = await message_future
|
||||
if isinstance(message, GroupChatTermination):
|
||||
# If the message contains an error, we need to raise it here.
|
||||
# This will stop the team and propagate the error.
|
||||
if message.error is not None:
|
||||
raise RuntimeError(str(message.error))
|
||||
stop_reason = message.message.content
|
||||
break
|
||||
yield message
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
output_messages.append(message)
|
||||
|
||||
# Yield the final result.
|
||||
yield TaskResult(messages=output_messages, stop_reason=stop_reason)
|
||||
|
||||
finally:
|
||||
try:
|
||||
if shutdown_task is not None:
|
||||
# Wait for the shutdown task to finish.
|
||||
# This will propagate any exceptions raised.
|
||||
await shutdown_task
|
||||
finally:
|
||||
# Clear the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and its participants to their initial state.
|
||||
|
||||
The team must be stopped before it can be reset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized or is currently running.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
# Reset the team.
|
||||
await team.reset()
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
if self._is_running:
|
||||
raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")
|
||||
self._is_running = True
|
||||
|
||||
if self._embedded_runtime:
|
||||
# Start the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
try:
|
||||
# Send a reset messages to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a reset message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
finally:
|
||||
if self._embedded_runtime:
|
||||
# Stop the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
await self._runtime.stop_when_idle()
|
||||
|
||||
# Reset the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def pause(self) -> None:
|
||||
"""Pause its participants when the team is running by calling their
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method via direct RPC calls.
|
||||
|
||||
.. attention::
|
||||
|
||||
This is an experimental feature introduced in v0.4.9 and may subject
|
||||
to change or removal in the future.
|
||||
|
||||
The team must be initialized before it can be paused.
|
||||
|
||||
Different from termination, pausing the team does not cause the
|
||||
:meth:`run` or :meth:`run_stream` method to return. It calls the
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method on each
|
||||
participant, and if the participant does not implement the method, it
|
||||
will be a no-op.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the responsibility of the agent class to handle the pause
|
||||
and ensure that the agent can be resumed later.
|
||||
Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_pause`
|
||||
method in your agent class for custom pause behavior.
|
||||
By default, the agent will not do anything when called.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized. Exceptions from
|
||||
the participants when calling their implementations of
|
||||
:class:`~agentdhal_agentchat.base.ChatAgent.on_pause` are
|
||||
propagated to this method and raised.
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("The group chat has not been initialized. It must be run before it can be paused.")
|
||||
|
||||
# Send a pause message to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatPause(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a pause message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatPause(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
|
||||
async def resume(self) -> None:
|
||||
"""Resume its participants when the team is running and paused by calling their
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method via direct RPC calls.
|
||||
|
||||
.. attention::
|
||||
|
||||
This is an experimental feature introduced in v0.4.9 and may subject
|
||||
to change or removal in the future.
|
||||
|
||||
The team must be initialized before it can be resumed.
|
||||
|
||||
Different from termination and restart with a new task, resuming the team
|
||||
does not cause the :meth:`run` or :meth:`run_stream` method to return.
|
||||
It calls the :meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method on each
|
||||
participant, and if the participant does not implement the method, it
|
||||
will be a no-op.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the responsibility of the agent class to handle the resume
|
||||
and ensure that the agent continues from where it was paused.
|
||||
Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_resume`
|
||||
method in your agent class for custom resume behavior.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized. Exceptions from
|
||||
the participants when calling their implementations of :class:`~agentdhal_agentchat.base.ChatAgent.on_resume`
|
||||
method are propagated to this method and raised.
|
||||
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("The group chat has not been initialized. It must be run before it can be resumed.")
|
||||
|
||||
# Send a resume message to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatResume(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a resume message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatResume(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the group chat team.
|
||||
|
||||
The state is saved by calling the :meth:`~agentdhal_core.AgentRuntime.agent_save_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
The state is returned as a nested dictionary: a dictionary with key `agent_states`,
|
||||
which is a dictionary the agent names as keys and the state as values.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
{
|
||||
"agent_states": {
|
||||
"agent1": ...,
|
||||
"agent2": ...,
|
||||
"RoundRobinGroupChatManager": ...
|
||||
}
|
||||
}
|
||||
|
||||
.. note::
|
||||
|
||||
Starting v0.4.9, the state is using the agent name as the key instead of the agent ID,
|
||||
and the `team_id` field is removed from the state. This is to allow the state to be
|
||||
portable across different teams and runtimes. States saved with the old format
|
||||
may not be compatible with the new format in the future.
|
||||
|
||||
.. caution::
|
||||
|
||||
When calling :func:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` on a team
|
||||
while it is running, the state may not be consistent and may result in an unexpected state.
|
||||
It is recommended to call this method when the team is not running or after it is stopped.
|
||||
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
# Store state of each agent by their name.
|
||||
# NOTE: we don't use the agent ID as the key here because we need to be able to decouple
|
||||
# the state of the agents from their identities in the agent runtime.
|
||||
agent_states: Dict[str, Mapping[str, Any]] = {}
|
||||
# Save the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
# NOTE: We are using the runtime's save state method rather than the agent instance's
|
||||
# save_state method because we want to support saving state of remote agents.
|
||||
agent_states[name] = await self._runtime.agent_save_state(agent_id)
|
||||
# Save the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
agent_states[self._group_chat_manager_name] = await self._runtime.agent_save_state(agent_id)
|
||||
return TeamState(agent_states=agent_states).model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load an external state and overwrite the current state of the group chat team.
|
||||
|
||||
The state is loaded by calling the :meth:`~agentdhal_core.AgentRuntime.agent_load_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
See :meth:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` for the expected format of the state.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
if self._is_running:
|
||||
raise RuntimeError("The team cannot be loaded while it is running.")
|
||||
self._is_running = True
|
||||
|
||||
try:
|
||||
team_state = TeamState.model_validate(state)
|
||||
# Load the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
if name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[name])
|
||||
# Load the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
if self._group_chat_manager_name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {self._group_chat_manager_name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[self._group_chat_manager_name])
|
||||
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"Invalid state format. The expected state format has changed since v0.4.9. "
|
||||
"Please read the release note on GitHub."
|
||||
) from e
|
||||
|
||||
finally:
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
@@ -0,0 +1,326 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from ...base import TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
"""Base class for a group chat manager that manages a group chat with multiple participants.
|
||||
|
||||
It is the responsibility of the caller to ensure:
|
||||
- All participants must subscribe to the group chat topic and each of their own topics.
|
||||
- The group chat manager must subscribe to the group chat topic.
|
||||
- The agent types of the participants must be unique.
|
||||
- For each participant, the agent type must be the same as the topic type.
|
||||
|
||||
Without the above conditions, the group chat will not function correctly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
description="Group chat manager",
|
||||
sequential_message_types=[
|
||||
GroupChatStart,
|
||||
GroupChatAgentResponse,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatMessage,
|
||||
GroupChatReset,
|
||||
],
|
||||
)
|
||||
if max_turns is not None and max_turns <= 0:
|
||||
raise ValueError("The maximum number of turns must be greater than 0.")
|
||||
if len(participant_topic_types) != len(participant_descriptions):
|
||||
raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.")
|
||||
if len(set(participant_topic_types)) != len(participant_topic_types):
|
||||
raise ValueError("The participant topic types must be unique.")
|
||||
if group_topic_type in participant_topic_types:
|
||||
raise ValueError("The group topic type must not be in the participant topic types.")
|
||||
self._name = name
|
||||
self._group_topic_type = group_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._participant_names = participant_names
|
||||
self._participant_name_to_topic_type = {
|
||||
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
|
||||
}
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
self._output_message_queue = output_message_queue
|
||||
self._termination_condition = termination_condition
|
||||
self._max_turns = max_turns
|
||||
self._current_turn = 0
|
||||
self._message_factory = message_factory
|
||||
self._emit_team_events = emit_team_events
|
||||
self._active_speakers: List[str] = []
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
|
||||
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(
|
||||
content="The group chat has already terminated.",
|
||||
source=self._name,
|
||||
)
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Validate the group state given the start messages
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
if message.messages is not None:
|
||||
# Log all messages at once
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
# Only put messages in output queue if output_task_messages is True
|
||||
if message.output_task_messages:
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Relay all messages at once to participants
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
# Append all messages to thread
|
||||
await self.update_message_thread(message.messages)
|
||||
|
||||
# Check termination condition after processing all messages
|
||||
if await self._apply_termination_condition(message.messages):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select speakers to start/continue the conversation
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(
|
||||
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
|
||||
) -> None:
|
||||
try:
|
||||
# Construct the detla from the agent response.
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if isinstance(message, GroupChatAgentResponse):
|
||||
if message.response.inner_messages is not None:
|
||||
for inner_message in message.response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
delta.append(message.response.chat_message)
|
||||
else:
|
||||
delta.extend(message.result.messages)
|
||||
|
||||
# Append the messages to the message thread.
|
||||
await self.update_message_thread(delta)
|
||||
|
||||
# Remove the agent from the active speakers list.
|
||||
self._active_speakers.remove(message.name)
|
||||
if len(self._active_speakers) > 0:
|
||||
# If there are still active speakers, return without doing anything.
|
||||
return
|
||||
|
||||
# Check if the conversation should be terminated.
|
||||
if await self._apply_termination_condition(delta, increment_turn_count=True):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select speakers to continue the conversation.
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
# Handle the exception and signal termination with an error.
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the exception to the runtime.
|
||||
raise
|
||||
|
||||
async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
|
||||
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
cancellation_token.link_future(speaker_names_future)
|
||||
speaker_names = await speaker_names_future
|
||||
if isinstance(speaker_names, str):
|
||||
# If only one speaker is selected, convert it to a list.
|
||||
speaker_names = [speaker_names]
|
||||
for speaker_name in speaker_names:
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
await self._log_speaker_selection(speaker_names)
|
||||
|
||||
# Send request to publish message to the next speakers
|
||||
for speaker_name in speaker_names:
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
self._active_speakers.append(speaker_name)
|
||||
|
||||
async def _apply_termination_condition(
|
||||
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
|
||||
) -> bool:
|
||||
"""Apply the termination condition to the delta and return True if the conversation should be terminated.
|
||||
It also resets the termination condition and turn count, and signals termination to the caller of the team."""
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions and turn count.
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return True
|
||||
if increment_turn_count:
|
||||
# Increment the turn count.
|
||||
self._current_turn += 1
|
||||
# Check if the maximum number of turns has been reached.
|
||||
if self._max_turns is not None:
|
||||
if self._current_turn >= self._max_turns:
|
||||
stop_message = StopMessage(
|
||||
content=f"Maximum number of turns {self._max_turns} reached.",
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _log_speaker_selection(self, speaker_names: List[str]) -> None:
|
||||
"""Log the selected speaker to the output message queue."""
|
||||
select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name)
|
||||
if self._emit_team_events:
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
async def _signal_termination(self, message: StopMessage) -> None:
|
||||
termination_event = GroupChatTermination(message=message)
|
||||
# Log the early stop message.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
async def _signal_termination_with_error(self, error: SerializableException) -> None:
|
||||
termination_event = GroupChatTermination(
|
||||
message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error
|
||||
)
|
||||
# Log the termination event.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat message by appending the content to its output message queue."""
|
||||
await self._output_message_queue.put(message.message)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat error by logging the error and signaling termination."""
|
||||
await self._signal_termination_with_error(message.error)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
"""Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager
|
||||
and clear the message thread."""
|
||||
await self.reset()
|
||||
|
||||
@rpc
|
||||
async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
|
||||
"""Pause the group chat manager. This is a no-op in the base class."""
|
||||
pass
|
||||
|
||||
@rpc
|
||||
async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None:
|
||||
"""Resume the group chat manager. This is a no-op in the base class."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the state of the group chat given the start messages.
|
||||
This is executed when the group chat manager receives a GroupChatStart event.
|
||||
|
||||
Args:
|
||||
messages: A list of chat messages to validate, or None if no messages are provided.
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
"""Update the message thread with the new messages.
|
||||
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
|
||||
before calling the select_speakers method.
|
||||
"""
|
||||
self._message_thread.extend(messages)
|
||||
|
||||
@abstractmethod
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select speakers from the participants and return the topic types of the selected speaker.
|
||||
This is called when the group chat manager have received all responses from the participants
|
||||
for a turn and is ready to select the next speakers for the next turn.
|
||||
|
||||
Args:
|
||||
thread: The message thread of the group chat.
|
||||
|
||||
Returns:
|
||||
A list of topic types of the selected speakers.
|
||||
If only one speaker is selected, a single string is returned instead of a list.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the group chat manager."""
|
||||
...
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in group chat manager: {type(message)}")
|
||||
@@ -0,0 +1,213 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from agentdhal_core import DefaultTopicId, MessageContext, event, rpc, trace_invoke_agent_span
|
||||
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
|
||||
from ...base import ChatAgent, Response, TaskResult, Team
|
||||
from ...state import ChatAgentContainerState
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class ChatAgentContainer(SequentialRoutedAgent):
|
||||
"""A core agent class that delegates message handling to an
|
||||
:class:`agentdhal_agentchat.base.ChatAgent` or :class:`agentdhal_agentchat.base.Team`
|
||||
so that it can be used in a group chat team.
|
||||
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
output_topic_type (str): The topic type for the output.
|
||||
agent (ChatAgent | Team): The agent or team to delegate message handling to.
|
||||
message_factory (MessageFactory): The message factory to use for
|
||||
creating messages from JSON data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent | Team, message_factory: MessageFactory
|
||||
) -> None:
|
||||
super().__init__(
|
||||
description=agent.description,
|
||||
sequential_message_types=[
|
||||
GroupChatStart,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatAgentResponse,
|
||||
GroupChatTeamResponse,
|
||||
],
|
||||
)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[BaseChatMessage] = []
|
||||
self._message_factory = message_factory
|
||||
|
||||
@event
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle a start event by appending the content to the buffer."""
|
||||
if message.messages is not None:
|
||||
for msg in message.messages:
|
||||
self._buffer_message(msg)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
"""Handle an agent response event by appending the content to the buffer."""
|
||||
self._buffer_message(message.response.chat_message)
|
||||
|
||||
@event
|
||||
async def handle_team_response(self, message: GroupChatTeamResponse, ctx: MessageContext) -> None:
|
||||
"""Handle a team response event by appending the content to the buffer."""
|
||||
for msg in message.result.messages:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
self._buffer_message(msg)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
"""Handle a reset event by resetting the agent."""
|
||||
self._message_buffer.clear()
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, reset the team.
|
||||
await self._agent.reset()
|
||||
else:
|
||||
await self._agent.on_reset(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
|
||||
"""Handle a content request event by passing the messages in the buffer
|
||||
to the delegate agent and publish the response."""
|
||||
if isinstance(self._agent, Team):
|
||||
try:
|
||||
stream = self._agent.run_stream(
|
||||
task=self._message_buffer,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
output_task_messages=False,
|
||||
)
|
||||
result: TaskResult | None = None
|
||||
async for team_event in stream:
|
||||
if isinstance(team_event, TaskResult):
|
||||
result = team_event
|
||||
else:
|
||||
await self._log_message(team_event)
|
||||
if result is None:
|
||||
raise RuntimeError(
|
||||
"The team did not produce a final TaskResult. Check the team's run_stream method."
|
||||
)
|
||||
self._message_buffer.clear()
|
||||
# Publish the team response to the group chat.
|
||||
await self.publish_message(
|
||||
GroupChatTeamResponse(result=result, name=self._agent.name),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Publish the error to the group chat.
|
||||
error_message = SerializableException.from_exception(e)
|
||||
await self.publish_message(
|
||||
GroupChatError(error=error_message),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
else:
|
||||
# If the agent is not a team, handle it as a single agent.
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self._agent.name,
|
||||
agent_description=self._agent.description,
|
||||
agent_id=str(self.id),
|
||||
):
|
||||
try:
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response: Response | None = None
|
||||
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
await self._log_message(msg.chat_message)
|
||||
response = msg
|
||||
else:
|
||||
await self._log_message(msg)
|
||||
if response is None:
|
||||
raise RuntimeError(
|
||||
"The agent did not produce a final response. Check the agent's on_messages_stream method."
|
||||
)
|
||||
# Publish the response to the group chat.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=response, name=self._agent.name),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Publish the error to the group chat.
|
||||
error_message = SerializableException.from_exception(e)
|
||||
await self.publish_message(
|
||||
GroupChatError(error=error_message),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
def _buffer_message(self, message: BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Buffer the message.
|
||||
self._message_buffer.append(message)
|
||||
|
||||
async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Log the message.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
@rpc
|
||||
async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
|
||||
"""Handle a pause event by pausing the agent."""
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, pause the team.
|
||||
await self._agent.pause()
|
||||
else:
|
||||
await self._agent.on_pause(ctx.cancellation_token)
|
||||
|
||||
@rpc
|
||||
async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None:
|
||||
"""Handle a resume event by resuming the agent."""
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, resume the team.
|
||||
await self._agent.resume()
|
||||
else:
|
||||
await self._agent.on_resume(ctx.cancellation_token)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
agent_state = await self._agent.save_state()
|
||||
state = ChatAgentContainerState(
|
||||
agent_state=agent_state, message_buffer=[message.dump() for message in self._message_buffer]
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
container_state = ChatAgentContainerState.model_validate(state)
|
||||
self._message_buffer = []
|
||||
for message_data in container_state.message_buffer:
|
||||
message = self._message_factory.create(message_data)
|
||||
if isinstance(message, BaseChatMessage):
|
||||
self._message_buffer.append(message)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in message buffer: {type(message)}")
|
||||
await self._agent.load_state(container_state.agent_state)
|
||||
113
agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py
Normal file
113
agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ...base import Response, TaskResult
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class SerializableException(BaseModel):
|
||||
"""A serializable exception."""
|
||||
|
||||
error_type: str
|
||||
"""The type of error that occurred."""
|
||||
|
||||
error_message: str
|
||||
"""The error message that describes the error."""
|
||||
|
||||
traceback: str | None = None
|
||||
"""The traceback of the error, if available."""
|
||||
|
||||
@classmethod
|
||||
def from_exception(cls, exc: Exception) -> "SerializableException":
|
||||
"""Create a GroupChatError from an exception."""
|
||||
return cls(
|
||||
error_type=type(exc).__name__,
|
||||
error_message=str(exc),
|
||||
traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the error, including the traceback if available."""
|
||||
if self.traceback:
|
||||
return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}"
|
||||
return f"{self.error_type}: {self.error_message}"
|
||||
|
||||
|
||||
class GroupChatStart(BaseModel):
|
||||
"""A request to start a group chat."""
|
||||
|
||||
messages: List[SerializeAsAny[BaseChatMessage]] | None = None
|
||||
"""An optional list of messages to start the group chat."""
|
||||
|
||||
output_task_messages: bool = True
|
||||
"""Whether to include task messages in the output. Defaults to True for backward compatibility."""
|
||||
|
||||
|
||||
class GroupChatAgentResponse(BaseModel):
|
||||
"""A response published to a group chat."""
|
||||
|
||||
response: SerializeAsAny[Response]
|
||||
"""The response from an agent."""
|
||||
|
||||
name: str
|
||||
"""The name of the agent that produced the response."""
|
||||
|
||||
|
||||
class GroupChatTeamResponse(BaseModel):
|
||||
"""A response published to a group chat from a team."""
|
||||
|
||||
result: SerializeAsAny[TaskResult]
|
||||
"""The result from a team."""
|
||||
|
||||
name: str
|
||||
"""The name of the team that produced the response."""
|
||||
|
||||
|
||||
class GroupChatRequestPublish(BaseModel):
|
||||
"""A request to publish a message to a group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatMessage(BaseModel):
|
||||
"""A message from a group chat."""
|
||||
|
||||
message: SerializeAsAny[BaseAgentEvent | BaseChatMessage]
|
||||
"""The message that was published."""
|
||||
|
||||
|
||||
class GroupChatTermination(BaseModel):
|
||||
"""A message indicating that a group chat has terminated."""
|
||||
|
||||
message: StopMessage
|
||||
"""The stop message that indicates the reason of termination."""
|
||||
|
||||
error: SerializableException | None = None
|
||||
"""The error that occurred, if any."""
|
||||
|
||||
|
||||
class GroupChatReset(BaseModel):
|
||||
"""A request to reset the agents in the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatPause(BaseModel):
|
||||
"""A request to pause the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatResume(BaseModel):
|
||||
"""A request to resume the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatError(BaseModel):
|
||||
"""A message indicating that an error occurred in the group chat."""
|
||||
|
||||
error: SerializableException
|
||||
"""The error that occurred."""
|
||||
@@ -0,0 +1,17 @@
|
||||
from ._digraph_group_chat import (
|
||||
DiGraph,
|
||||
DiGraphEdge,
|
||||
DiGraphNode,
|
||||
GraphFlow,
|
||||
GraphFlowManager,
|
||||
)
|
||||
from ._graph_builder import DiGraphBuilder
|
||||
|
||||
__all__ = [
|
||||
"GraphFlow",
|
||||
"DiGraph",
|
||||
"GraphFlowManager",
|
||||
"DiGraphNode",
|
||||
"DiGraphEdge",
|
||||
"DiGraphBuilder",
|
||||
]
|
||||
@@ -0,0 +1,877 @@
|
||||
import asyncio
|
||||
from collections import Counter, deque
|
||||
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.base import ChatAgent, TerminationCondition
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
StopMessage,
|
||||
)
|
||||
from agentdhal_agentchat.state import BaseGroupChatManagerState
|
||||
from agentdhal_agentchat.teams import BaseGroupChat
|
||||
|
||||
from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
|
||||
from ..._group_chat._events import GroupChatTermination
|
||||
|
||||
_DIGRAPH_STOP_MESSAGE = "Digraph execution is complete"
|
||||
|
||||
|
||||
class DiGraphEdge(BaseModel):
|
||||
"""Represents a directed edge in a :class:`DiGraph`, with an optional execution condition.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
.. warning::
|
||||
|
||||
If the condition is a callable, it will not be serialized in the model.
|
||||
|
||||
"""
|
||||
|
||||
target: str # Target node name
|
||||
condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None)
|
||||
"""(Experimental) Condition to execute this edge.
|
||||
If None, the edge is unconditional.
|
||||
If a string, the edge is conditional on the presence of that string in the last agent chat message.
|
||||
If a callable, the edge is conditional on the callable returning True when given the last message.
|
||||
"""
|
||||
|
||||
# Using Field to exclude the condition in serialization if it's a callable
|
||||
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
|
||||
activation_group: str = Field(default="")
|
||||
"""Group identifier for forward dependencies.
|
||||
|
||||
When multiple edges point to the same target node, they are grouped by this field.
|
||||
This allows distinguishing between different cycles or dependency patterns.
|
||||
|
||||
Example: In a graph containing a cycle like A->B->C->B, the two edges pointing to B (A->B and C->B)
|
||||
can be in different activation groups to control how B is activated.
|
||||
Defaults to the target node name if not specified.
|
||||
"""
|
||||
activation_condition: Literal["all", "any"] = "all"
|
||||
"""Determines how forward dependencies within the same activation_group are evaluated.
|
||||
|
||||
- "all": All edges in this activation group must be satisfied before the target node can execute
|
||||
- "any": Any single edge in this activation group being satisfied allows the target node to execute
|
||||
|
||||
This is used to handle complex dependency patterns in cyclic graphs where multiple
|
||||
paths can lead to the same target node.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_condition(self) -> "DiGraphEdge":
|
||||
# Store callable in a separate field and set condition to None for serialization
|
||||
if callable(self.condition):
|
||||
self.condition_function = self.condition
|
||||
# For serialization purposes, we'll set the condition to None
|
||||
# when storing as a pydantic model/dict
|
||||
object.__setattr__(self, "condition", None)
|
||||
|
||||
# Set activation_group to target if not already set
|
||||
if not self.activation_group:
|
||||
self.activation_group = self.target
|
||||
|
||||
return self
|
||||
|
||||
def check_condition(self, message: BaseChatMessage) -> bool:
|
||||
"""Check if the edge condition is satisfied for the given message.
|
||||
|
||||
Args:
|
||||
message: The message to check the condition against.
|
||||
|
||||
Returns:
|
||||
True if condition is satisfied (None condition always returns True),
|
||||
False otherwise.
|
||||
"""
|
||||
if self.condition_function is not None:
|
||||
return self.condition_function(message)
|
||||
elif isinstance(self.condition, str):
|
||||
# If it's a string, check if the string is in the message content
|
||||
return self.condition in message.to_model_text()
|
||||
return True # None condition is always satisfied
|
||||
|
||||
|
||||
class DiGraphNode(BaseModel):
|
||||
"""Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
"""
|
||||
|
||||
name: str # Agent's name
|
||||
edges: List[DiGraphEdge] = [] # Outgoing edges
|
||||
activation: Literal["all", "any"] = "all"
|
||||
|
||||
|
||||
class DiGraph(BaseModel):
|
||||
"""Defines a directed graph structure with nodes and edges.
|
||||
:class:`GraphFlow` uses this to determine execution order and conditions.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
"""
|
||||
|
||||
nodes: Dict[str, DiGraphNode] # Node name → DiGraphNode mapping
|
||||
default_start_node: str | None = None # Default start node name
|
||||
_has_cycles: bool | None = None # Cyclic graph flag
|
||||
|
||||
def get_parents(self) -> Dict[str, List[str]]:
|
||||
"""Compute a mapping of each node to its parent nodes."""
|
||||
parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
parents[edge.target].append(node.name)
|
||||
return parents
|
||||
|
||||
def get_start_nodes(self) -> Set[str]:
|
||||
"""Return the nodes that have no incoming edges (entry points)."""
|
||||
if self.default_start_node:
|
||||
return {self.default_start_node}
|
||||
|
||||
parents = self.get_parents()
|
||||
return set([node_name for node_name, parent_list in parents.items() if not parent_list])
|
||||
|
||||
def get_leaf_nodes(self) -> Set[str]:
|
||||
"""Return nodes that have no outgoing edges (final output nodes)."""
|
||||
return set([name for name, node in self.nodes.items() if not node.edges])
|
||||
|
||||
def has_cycles_with_exit(self) -> bool:
|
||||
"""
|
||||
Check if the graph has any cycles and validate that each cycle has at least one conditional edge.
|
||||
|
||||
Returns:
|
||||
bool: True if there is at least one cycle and all cycles have an exit condition.
|
||||
False if there are no cycles.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is a cycle without any conditional edge.
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
rec_stack: Set[str] = set()
|
||||
path: List[str] = []
|
||||
|
||||
def dfs(node_name: str) -> bool:
|
||||
visited.add(node_name)
|
||||
rec_stack.add(node_name)
|
||||
path.append(node_name)
|
||||
|
||||
for edge in self.nodes[node_name].edges:
|
||||
target = edge.target
|
||||
if target not in visited:
|
||||
if dfs(target):
|
||||
return True
|
||||
elif target in rec_stack:
|
||||
# Found a cycle → extract the cycle
|
||||
cycle_start_index = path.index(target)
|
||||
cycle_nodes = path[cycle_start_index:]
|
||||
cycle_edges: List[DiGraphEdge] = []
|
||||
for n in cycle_nodes:
|
||||
cycle_edges.extend(self.nodes[n].edges)
|
||||
if all(edge.condition is None and edge.condition_function is None for edge in cycle_edges):
|
||||
raise ValueError(
|
||||
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
|
||||
)
|
||||
return True # Found cycle, but it has an exit condition
|
||||
|
||||
rec_stack.remove(node_name)
|
||||
path.pop()
|
||||
return False
|
||||
|
||||
has_cycle = False
|
||||
for node in self.nodes:
|
||||
if node not in visited:
|
||||
if dfs(node):
|
||||
has_cycle = True
|
||||
|
||||
return has_cycle
|
||||
|
||||
def get_has_cycles(self) -> bool:
|
||||
"""Indicates if the graph has at least one cycle (with valid exit conditions)."""
|
||||
if self._has_cycles is None:
|
||||
self._has_cycles = self.has_cycles_with_exit()
|
||||
|
||||
return self._has_cycles
|
||||
|
||||
def graph_validate(self) -> None:
|
||||
"""Validate graph structure and execution rules."""
|
||||
if not self.nodes:
|
||||
raise ValueError("Graph has no nodes.")
|
||||
|
||||
if not self.get_start_nodes():
|
||||
raise ValueError("Graph must have at least one start node")
|
||||
|
||||
if not self.get_leaf_nodes():
|
||||
raise ValueError("Graph must have at least one leaf node")
|
||||
|
||||
# Outgoing edge condition validation (per node)
|
||||
for node in self.nodes.values():
|
||||
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
||||
has_condition = any(
|
||||
edge.condition is not None or edge.condition_function is not None for edge in node.edges
|
||||
)
|
||||
has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges)
|
||||
if has_condition and has_unconditioned:
|
||||
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
||||
|
||||
# Validate activation conditions across all edges in the graph
|
||||
self._validate_activation_conditions()
|
||||
|
||||
self._has_cycles = self.has_cycles_with_exit()
|
||||
|
||||
def _validate_activation_conditions(self) -> None:
|
||||
"""Validate that all edges pointing to the same target node have consistent activation_condition values.
|
||||
|
||||
Raises:
|
||||
ValueError: If edges pointing to the same target have different activation_condition values
|
||||
"""
|
||||
target_activation_conditions: Dict[str, Dict[str, str]] = {} # target_node -> {activation_group -> condition}
|
||||
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target # The target node this edge points to
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if target not in target_activation_conditions:
|
||||
target_activation_conditions[target] = {}
|
||||
|
||||
if activation_group in target_activation_conditions[target]:
|
||||
if target_activation_conditions[target][activation_group] != edge.activation_condition:
|
||||
# Find the source node that has the conflicting condition
|
||||
conflicting_source = self._find_edge_source_by_target_and_group(
|
||||
target, activation_group, target_activation_conditions[target][activation_group]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Conflicting activation conditions for target '{target}' group '{activation_group}': "
|
||||
f"'{target_activation_conditions[target][activation_group]}' (from node '{conflicting_source}') "
|
||||
f"and '{edge.activation_condition}' (from node '{node.name}')"
|
||||
)
|
||||
else:
|
||||
target_activation_conditions[target][activation_group] = edge.activation_condition
|
||||
|
||||
def _find_edge_source_by_target_and_group(
|
||||
self, target: str, activation_group: str, activation_condition: str
|
||||
) -> str:
|
||||
"""Find the source node that has an edge pointing to the given target with the given activation_group and activation_condition."""
|
||||
for node_name, node in self.nodes.items():
|
||||
for edge in node.edges:
|
||||
if (
|
||||
edge.target == target
|
||||
and edge.activation_group == activation_group
|
||||
and edge.activation_condition == activation_condition
|
||||
):
|
||||
return node_name
|
||||
return "unknown"
|
||||
|
||||
def get_remaining_map(self) -> Dict[str, Dict[str, int]]:
|
||||
"""Get the remaining map that tracks how many edges point to each target node with each activation group.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping target nodes to their activation groups and remaining counts
|
||||
"""
|
||||
|
||||
remaining_map: Dict[str, Dict[str, int]] = {}
|
||||
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if target not in remaining_map:
|
||||
remaining_map[target] = {}
|
||||
|
||||
if activation_group not in remaining_map[target]:
|
||||
remaining_map[target][activation_group] = 0
|
||||
|
||||
remaining_map[target][activation_group] += 1
|
||||
|
||||
return remaining_map
|
||||
|
||||
|
||||
class GraphFlowManagerState(BaseGroupChatManagerState):
|
||||
"""Tracks active execution state for DAG-based execution."""
|
||||
|
||||
active_nodes: List[str] = [] # Currently executing nodes
|
||||
type: str = "GraphManagerState"
|
||||
|
||||
|
||||
class GraphFlowManager(BaseGroupChatManager):
|
||||
"""Manages execution of agents using a Directed Graph execution model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
graph: DiGraph,
|
||||
) -> None:
|
||||
"""Initialize the graph-based execution manager."""
|
||||
super().__init__(
|
||||
name=name,
|
||||
group_topic_type=group_topic_type,
|
||||
output_topic_type=output_topic_type,
|
||||
participant_topic_types=participant_topic_types,
|
||||
participant_names=participant_names,
|
||||
participant_descriptions=participant_descriptions,
|
||||
output_message_queue=output_message_queue,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
message_factory=message_factory,
|
||||
)
|
||||
graph.graph_validate()
|
||||
if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
|
||||
raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
|
||||
self._graph = graph
|
||||
# Lookup table for incoming edges for each node.
|
||||
self._parents = graph.get_parents()
|
||||
# Lookup table for outgoing edges for each node.
|
||||
self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
|
||||
|
||||
# Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node
|
||||
self._build_lookup_tables(graph)
|
||||
|
||||
# Track which activation groups were triggered for each node
|
||||
self._triggered_activation_groups: Dict[str, Set[str]] = {}
|
||||
# === Mutable states for the graph execution ===
|
||||
# Count the number of remaining parents to activate each node.
|
||||
self._remaining: Dict[str, Counter[str]] = {
|
||||
target: Counter(groups) for target, groups in graph.get_remaining_map().items()
|
||||
}
|
||||
# cache for remaining
|
||||
self._origin_remaining: Dict[str, Dict[str, int]] = {
|
||||
target: Counter(groups) for target, groups in self._remaining.items()
|
||||
}
|
||||
|
||||
# Ready queue for nodes that are ready to execute, starting with the start nodes.
|
||||
self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()])
|
||||
|
||||
def _build_lookup_tables(self, graph: DiGraph) -> None:
|
||||
"""Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node.
|
||||
|
||||
Args:
|
||||
graph: The directed graph
|
||||
"""
|
||||
self._activation: Dict[str, Dict[str, Literal["any", "all"]]] = {}
|
||||
self._enqueued_any: Dict[str, Dict[str, bool]] = {}
|
||||
|
||||
for node in graph.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
# Build activation lookup
|
||||
if target not in self._activation:
|
||||
self._activation[target] = {}
|
||||
if activation_group not in self._activation[target]:
|
||||
self._activation[target][activation_group] = edge.activation_condition
|
||||
|
||||
# Build enqueued_any lookup
|
||||
if target not in self._enqueued_any:
|
||||
self._enqueued_any[target] = {}
|
||||
if activation_group not in self._enqueued_any[target]:
|
||||
self._enqueued_any[target][activation_group] = False
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
await super().update_message_thread(messages)
|
||||
|
||||
# Find the node that ran in the current turn.
|
||||
message = messages[-1]
|
||||
if message.source not in self._graph.nodes:
|
||||
# Ignore messages from sources outside of the graph.
|
||||
return
|
||||
assert isinstance(message, BaseChatMessage)
|
||||
source = message.source
|
||||
|
||||
# Propagate the update to the children of the node.
|
||||
for edge in self._edges[source]:
|
||||
# Use the new check_condition method that handles both string and callable conditions
|
||||
if not edge.check_condition(message):
|
||||
continue
|
||||
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if self._activation[target][activation_group] == "all":
|
||||
self._remaining[target][activation_group] -= 1
|
||||
if self._remaining[target][activation_group] == 0:
|
||||
# If all parents are done, add to the ready queue.
|
||||
self._ready.append(target)
|
||||
# Track which activation group was triggered
|
||||
self._save_triggered_activation_group(target, activation_group)
|
||||
else:
|
||||
# If activation is any, add to the ready queue if not already enqueued.
|
||||
if not self._enqueued_any[target][activation_group]:
|
||||
self._ready.append(target)
|
||||
self._enqueued_any[target][activation_group] = True
|
||||
# Track which activation group was triggered
|
||||
self._save_triggered_activation_group(target, activation_group)
|
||||
|
||||
def _save_triggered_activation_group(self, target: str, activation_group: str) -> None:
|
||||
"""Save which activation group was triggered for a target node.
|
||||
|
||||
Args:
|
||||
target: The target node that was triggered
|
||||
activation_group: The activation group that caused the trigger
|
||||
"""
|
||||
if target not in self._triggered_activation_groups:
|
||||
self._triggered_activation_groups[target] = set()
|
||||
self._triggered_activation_groups[target].add(activation_group)
|
||||
|
||||
def _reset_triggered_activation_groups(self, speaker: str) -> None:
|
||||
"""Reset the bookkeeping for the specific activation groups that were triggered for a speaker.
|
||||
|
||||
Args:
|
||||
speaker: The speaker node to reset activation groups for
|
||||
"""
|
||||
if speaker not in self._triggered_activation_groups:
|
||||
return
|
||||
|
||||
for activation_group in self._triggered_activation_groups[speaker]:
|
||||
if self._activation[speaker][activation_group] == "any":
|
||||
self._enqueued_any[speaker][activation_group] = False
|
||||
else:
|
||||
# Reset the remaining count for this activation group using the graph's original count
|
||||
if speaker in self._remaining and activation_group in self._remaining[speaker]:
|
||||
self._remaining[speaker][activation_group] = self._origin_remaining[speaker][activation_group]
|
||||
|
||||
# Clear the triggered activation groups for this speaker
|
||||
self._triggered_activation_groups[speaker].clear()
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
# Drain the ready queue for the next set of speakers.
|
||||
speakers: List[str] = []
|
||||
while self._ready:
|
||||
speaker = self._ready.popleft()
|
||||
speakers.append(speaker)
|
||||
|
||||
# Reset the bookkeeping for the specific activation groups that were triggered
|
||||
self._reset_triggered_activation_groups(speaker)
|
||||
|
||||
return speakers
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def _apply_termination_condition(
|
||||
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
|
||||
) -> bool:
|
||||
"""Apply termination condition including graph-specific completion logic.
|
||||
|
||||
First checks if graph execution is complete, then checks standard termination conditions.
|
||||
|
||||
Args:
|
||||
delta: The message delta to check termination conditions against
|
||||
increment_turn_count: Whether to increment the turn count
|
||||
|
||||
Returns:
|
||||
True if the conversation should be terminated, False otherwise
|
||||
"""
|
||||
# Check if the graph execution is complete (no ready speakers) - prioritize this check
|
||||
if not self._ready:
|
||||
stop_message = StopMessage(
|
||||
content=_DIGRAPH_STOP_MESSAGE,
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the execution state when the graph has naturally completed
|
||||
self._reset_execution_state()
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
return True
|
||||
|
||||
# Apply the standard termination conditions from the base class
|
||||
return await super()._apply_termination_condition(delta, increment_turn_count)
|
||||
|
||||
def _reset_execution_state(self) -> None:
|
||||
"""Reset the graph execution state to the initial state."""
|
||||
self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
|
||||
self._enqueued_any = {n: {g: False for g in self._enqueued_any[n]} for n in self._enqueued_any}
|
||||
self._ready = deque([n for n in self._graph.get_start_nodes()])
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the execution state."""
|
||||
state = {
|
||||
"message_thread": [message.dump() for message in self._message_thread],
|
||||
"current_turn": self._current_turn,
|
||||
"remaining": {target: dict(counter) for target, counter in self._remaining.items()},
|
||||
"enqueued_any": dict(self._enqueued_any),
|
||||
"ready": list(self._ready),
|
||||
}
|
||||
return state
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore execution state from saved data."""
|
||||
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
|
||||
self._current_turn = state["current_turn"]
|
||||
self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()}
|
||||
self._enqueued_any = state["enqueued_any"]
|
||||
self._ready = deque(state["ready"])
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset execution state to the start of the graph."""
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition:
|
||||
await self._termination_condition.reset()
|
||||
self._reset_execution_state()
|
||||
|
||||
|
||||
class GraphFlowConfig(BaseModel):
|
||||
"""The declarative configuration for GraphFlow."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
graph: DiGraph # The execution graph for agents
|
||||
|
||||
|
||||
class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
"""A team that runs a group chat following a Directed Graph execution pattern.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This group chat executes agents based on a directed graph (:class:`DiGraph`) structure,
|
||||
allowing complex workflows such as sequential execution, parallel fan-out,
|
||||
conditional branching, join patterns, and loops with explicit exit conditions.
|
||||
|
||||
The execution order is determined by the edges defined in the `DiGraph`. Each node
|
||||
in the graph corresponds to an agent, and edges define the flow of messages between agents.
|
||||
Nodes can be configured to activate when:
|
||||
|
||||
- **All** parent nodes have completed (activation="all") → default
|
||||
- **Any** parent node completes (activation="any")
|
||||
|
||||
Conditional branching is supported using edge conditions, where the next agent(s) are selected
|
||||
based on content in the chat history. Loops are permitted as long as there is a condition
|
||||
that eventually exits the loop.
|
||||
|
||||
.. note::
|
||||
|
||||
Use the :class:`DiGraphBuilder` class to create a :class:`DiGraph` easily. It provides a fluent API
|
||||
for adding nodes and edges, setting entry points, and validating the graph structure.
|
||||
See the :class:`DiGraphBuilder` documentation for more details.
|
||||
The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows.
|
||||
|
||||
.. warning::
|
||||
|
||||
When using callable conditions in edges, they will not be serialized
|
||||
when calling :meth:`dump_component`. This will be addressed in future releases.
|
||||
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The participants in the group chat.
|
||||
termination_condition (TerminationCondition, optional): Termination condition for the chat.
|
||||
max_turns (int, optional): Maximum number of turns before forcing termination.
|
||||
graph (DiGraph): Directed execution graph defining node flow and conditions.
|
||||
|
||||
Raises:
|
||||
ValueError: If participant names are not unique, or if graph validation fails (e.g., cycles without exit).
|
||||
|
||||
Examples:
|
||||
|
||||
**Sequential Flow: A → B → C**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to English.")
|
||||
|
||||
# Create a directed graph with sequential flow A -> B -> C.
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short story about a cat."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Parallel Fan-out: A → (B, C)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")
|
||||
|
||||
# Create a directed graph with fan-out flow A -> (B, C).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short story about a cat."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Conditional Branching: A → B (if 'yes') or C (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent(
|
||||
"A",
|
||||
model_client=model_client,
|
||||
system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
|
||||
)
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
|
||||
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
# Create conditions as callables that check the message content.
|
||||
builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
|
||||
builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Loop with exit condition: A → B → C (if 'APPROVE') or A (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
agent_a = AssistantAgent(
|
||||
"A",
|
||||
model_client=model_client,
|
||||
system_message="You are a helpful assistant.",
|
||||
)
|
||||
agent_b = AssistantAgent(
|
||||
"B",
|
||||
model_client=model_client,
|
||||
system_message="Provide feedback on the input, if your feedback has been addressed, "
|
||||
"say 'APPROVE', otherwise provide a reason for rejection.",
|
||||
)
|
||||
agent_c = AssistantAgent(
|
||||
"C", model_client=model_client, system_message="Translate the final product to Korean."
|
||||
)
|
||||
|
||||
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b)
|
||||
|
||||
# Create conditional edges using strings
|
||||
builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text())
|
||||
builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text())
|
||||
|
||||
builder.set_entry_point(agent_a)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(20), # Max 20 messages to avoid infinite loop.
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short poem about AI Agents."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = GraphFlowConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.GraphFlow"
|
||||
|
||||
DEFAULT_NAME = "GraphFlow"
|
||||
DEFAULT_DESCRIPTION = "A team of agents"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
graph: DiGraph,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
) -> None:
|
||||
self._input_participants = participants
|
||||
self._input_termination_condition = termination_condition
|
||||
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
|
||||
# No longer add _StopAgent or StopMessageTermination
|
||||
# Termination is now handled directly in GraphFlowManager._apply_termination_condition
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=list(participants),
|
||||
group_chat_manager_name="GraphManager",
|
||||
group_chat_manager_class=GraphFlowManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
)
|
||||
self._graph = graph
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], GraphFlowManager]:
|
||||
"""Creates the factory method for initializing the DiGraph-based chat manager."""
|
||||
|
||||
def _factory() -> GraphFlowManager:
|
||||
return GraphFlowManager(
|
||||
name=name,
|
||||
group_topic_type=group_topic_type,
|
||||
output_topic_type=output_topic_type,
|
||||
participant_topic_types=participant_topic_types,
|
||||
participant_names=participant_names,
|
||||
participant_descriptions=participant_descriptions,
|
||||
output_message_queue=output_message_queue,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
message_factory=message_factory,
|
||||
graph=self._graph,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> GraphFlowConfig:
|
||||
"""Converts the instance into a configuration object."""
|
||||
participants = [participant.dump_component() for participant in self._input_participants]
|
||||
termination_condition = (
|
||||
self._input_termination_condition.dump_component() if self._input_termination_condition else None
|
||||
)
|
||||
return GraphFlowConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
graph=self._graph,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: GraphFlowConfig) -> Self:
|
||||
"""Reconstructs an instance from a configuration object."""
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
participants=participants,
|
||||
graph=config.graph,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
)
|
||||
@@ -0,0 +1,209 @@
|
||||
import warnings
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
from agentdhal_agentchat.base import ChatAgent
|
||||
from agentdhal_agentchat.messages import BaseChatMessage
|
||||
|
||||
from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
|
||||
|
||||
|
||||
class DiGraphBuilder:
|
||||
"""
|
||||
A fluent builder for constructing :class:`DiGraph` execution graphs used in :class:`GraphFlow`.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This utility provides a convenient way to programmatically build a graph of agent interactions,
|
||||
including complex execution flows such as:
|
||||
|
||||
- Sequential chains
|
||||
- Parallel fan-outs
|
||||
- Conditional branching
|
||||
- Cyclic loops with safe exits
|
||||
|
||||
Each node in the graph represents an agent. Edges define execution paths between agents,
|
||||
and can optionally be conditioned on message content using callable functions.
|
||||
|
||||
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
|
||||
|
||||
Methods:
|
||||
- add_node(agent, activation): Add an agent node to the graph.
|
||||
- add_edge(source, target, condition): Connect two nodes optionally with a condition.
|
||||
- add_conditional_edges(source, condition_to_target): Add multiple conditional edges from a source.
|
||||
- set_entry_point(agent): Define the default start node (optional).
|
||||
- build(): Generate a validated `DiGraph`.
|
||||
- get_participants(): Return the list of added agents.
|
||||
|
||||
Example — Sequential Flow A → B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
|
||||
>>> team = Graph(
|
||||
... participants=builder.get_participants(),
|
||||
... graph=builder.build(),
|
||||
... termination_condition=MaxMessageTermination(5),
|
||||
... )
|
||||
|
||||
Example — Parallel Fan-out A → (B, C):
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
|
||||
Example — Conditional Branching A → B or A → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add conditional edges using keyword check
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="keyword1")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="keyword2")
|
||||
|
||||
|
||||
Example — Using Custom String Conditions:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add condition strings to check in messages
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="big")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="small")
|
||||
|
||||
Example — Loop: A → B → A or B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>> # Add a loop back to agent A
|
||||
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
|
||||
>>> # Add exit condition to break the loop
|
||||
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())
|
||||
|
||||
Example — Loop with multiple paths to the same node: A → B → C → B:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>>> builder.add_edge(agent_b, agent_c)
|
||||
>>> builder.add_edge(agent_c, agent_b, activation_group="loop_back")
|
||||
|
||||
Example — Loop with multiple paths to the same node with any activation condition: A → B → (C1, C2) → B → E(exit):
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c1).add_node(agent_c2).add_node(agent_e)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>>> builder.add_edge(agent_b, agent_c1)
|
||||
>>> builder.add_edge(agent_b, agent_c2)
|
||||
>>> builder.add_edge(agent_b, agent_e, condition="exit")
|
||||
>>> builder.add_edge(agent_c1, agent_b, activation_group="loop_back_group", activation_condition="any")
|
||||
>>> builder.add_edge(agent_c2, agent_b, activation_group="loop_back_group", activation_condition="any")
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.nodes: Dict[str, DiGraphNode] = {}
|
||||
self.agents: Dict[str, ChatAgent] = {}
|
||||
self._default_start_node: Optional[str] = None
|
||||
|
||||
def _get_name(self, obj: Union[str, ChatAgent]) -> str:
|
||||
return obj if isinstance(obj, str) else obj.name
|
||||
|
||||
def add_node(self, agent: ChatAgent, activation: Literal["all", "any"] = "all") -> "DiGraphBuilder":
|
||||
"""Add a node to the graph and register its agent."""
|
||||
name = agent.name
|
||||
if name not in self.nodes:
|
||||
self.nodes[name] = DiGraphNode(name=name, edges=[], activation=activation)
|
||||
self.agents[name] = agent
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Union[str, ChatAgent],
|
||||
target: Union[str, ChatAgent],
|
||||
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
|
||||
activation_group: Optional[str] = None,
|
||||
activation_condition: Optional[Literal["all", "any"]] = None,
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add a directed edge from source to target, optionally with a condition.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
target: Target node (agent name or agent object)
|
||||
condition: Optional condition for edge activation.
|
||||
If string, activates when substring is found in message.
|
||||
If callable, activates when function returns True for the message.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
|
||||
Raises:
|
||||
ValueError: If source or target node doesn't exist in the builder
|
||||
"""
|
||||
source_name = self._get_name(source)
|
||||
target_name = self._get_name(target)
|
||||
|
||||
if source_name not in self.nodes:
|
||||
raise ValueError(f"Source node '{source_name}' must be added before adding an edge.")
|
||||
if target_name not in self.nodes:
|
||||
raise ValueError(f"Target node '{target_name}' must be added before adding an edge.")
|
||||
if activation_group is None:
|
||||
activation_group = target_name
|
||||
if activation_condition is None:
|
||||
activation_condition = "all"
|
||||
self.nodes[source_name].edges.append(
|
||||
DiGraphEdge(
|
||||
target=target_name,
|
||||
condition=condition,
|
||||
activation_group=activation_group,
|
||||
activation_condition=activation_condition,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def add_conditional_edges(
|
||||
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add multiple conditional edges from a source node based on keyword checks.
|
||||
|
||||
.. warning::
|
||||
|
||||
This method interface will be changed in the future to support callable conditions.
|
||||
Please use `add_edge` if you need to specify custom conditions.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
condition_to_target: Mapping from condition strings to target nodes
|
||||
Each key is a keyword that will be checked in the message content
|
||||
Each value is the target node to activate when condition is met
|
||||
|
||||
For each key (keyword), a lambda will be created that checks
|
||||
if the keyword is in the message text.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"add_conditional_edges will be changed in the future to support callable conditions. "
|
||||
"For now, please use add_edge if you need to specify custom conditions.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
for condition_keyword, target in condition_to_target.items():
|
||||
self.add_edge(source, target, condition=condition_keyword)
|
||||
return self
|
||||
|
||||
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":
|
||||
"""Set the default start node of the graph."""
|
||||
node_name = self._get_name(name)
|
||||
if node_name not in self.nodes:
|
||||
raise ValueError(f"Start node '{node_name}' must be added before setting as entry point.")
|
||||
self._default_start_node = node_name
|
||||
return self
|
||||
|
||||
def build(self) -> DiGraph:
|
||||
"""Build and validate the DiGraph."""
|
||||
graph = DiGraph(
|
||||
nodes=self.nodes,
|
||||
default_start_node=self._default_start_node,
|
||||
)
|
||||
graph.graph_validate()
|
||||
return graph
|
||||
|
||||
def get_participants(self) -> list[ChatAgent]:
|
||||
"""Return the list of agents in the builder, in insertion order."""
|
||||
return list(self.agents.values())
|
||||
@@ -0,0 +1,5 @@
|
||||
from ._magentic_one_group_chat import MagenticOneGroupChat
|
||||
|
||||
__all__ = [
|
||||
"MagenticOneGroupChat",
|
||||
]
|
||||
@@ -0,0 +1,209 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, List
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ....base import ChatAgent, TerminationCondition
|
||||
from ....messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from .._base_group_chat import BaseGroupChat
|
||||
from .._events import GroupChatTermination
|
||||
from ._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||
from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class MagenticOneGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration for a MagenticOneGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
model_client: ComponentModel
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
max_stalls: int
|
||||
final_answer_prompt: str
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
|
||||
"""A team that runs a group chat with participants managed by the MagenticOneOrchestrator.
|
||||
|
||||
The orchestrator handles the conversation flow, ensuring that the task is completed
|
||||
efficiently by managing the participants' interactions.
|
||||
|
||||
The orchestrator is based on the Magentic-One architecture, which is a generalist multi-agent system for solving complex tasks (see references below).
|
||||
|
||||
Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and :class:`~agentdhal_agentchat.teams.SelectorGroupChat`,
|
||||
the MagenticOneGroupChat does not support using team as participant.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The participants in the group chat.
|
||||
model_client (ChatCompletionClient): The model client used for generating responses.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run based on the orchestrator logic or until the maximum number of turns is reached.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
|
||||
max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
|
||||
final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
|
||||
|
||||
Examples:
|
||||
|
||||
MagenticOneGroupChat with one assistant agent:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import MagenticOneGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
)
|
||||
team = MagenticOneGroupChat([assistant], model_client=model_client)
|
||||
await Console(team.run_stream(task="Provide a different proof to Fermat last theorem"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
References:
|
||||
|
||||
If you use the MagenticOneGroupChat in your work, please cite the following paper:
|
||||
|
||||
.. code-block:: bibtex
|
||||
|
||||
@article{fourney2024magentic,
|
||||
title={Magentic-one: A generalist multi-agent system for solving complex tasks},
|
||||
author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others},
|
||||
journal={arXiv preprint arXiv:2411.04468},
|
||||
year={2024}
|
||||
}
|
||||
"""
|
||||
|
||||
component_config_schema = MagenticOneGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.MagenticOneGroupChat"
|
||||
|
||||
DEFAULT_NAME = "MagenticOneGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = 20,
|
||||
runtime: AgentRuntime | None = None,
|
||||
max_stalls: int = 3,
|
||||
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=list(participants),
|
||||
group_chat_manager_name="MagenticOneOrchestrator",
|
||||
group_chat_manager_class=MagenticOneOrchestrator,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
# Validate the participants.
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required for MagenticOneGroupChat.")
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
self._final_answer_prompt = final_answer_prompt
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], MagenticOneOrchestrator]:
|
||||
return lambda: MagenticOneOrchestrator(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._model_client,
|
||||
self._max_stalls,
|
||||
self._final_answer_prompt,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
def _to_config(self) -> MagenticOneGroupChatConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return MagenticOneGroupChatConfig(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
participants=participants,
|
||||
model_client=self._model_client.dump_component(),
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
max_stalls=self._max_stalls,
|
||||
final_answer_prompt=self._final_answer_prompt,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MagenticOneGroupChatConfig) -> Self:
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants=participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
model_client=model_client,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
max_stalls=config.max_stalls,
|
||||
final_answer_prompt=config.final_answer_prompt,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
@@ -0,0 +1,536 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agentdhal_core.utils import extract_json_from_str
|
||||
|
||||
from .... import TRACE_LOGGER_NAME
|
||||
from ....base import Response, TerminationCondition
|
||||
from ....messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
MultiModalMessage,
|
||||
SelectSpeakerEvent,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ....state import MagenticOneOrchestratorState
|
||||
from ....utils import remove_images
|
||||
from .._base_group_chat_manager import BaseGroupChatManager
|
||||
from .._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatMessage,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._prompts import (
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
|
||||
LedgerEntry,
|
||||
)
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
|
||||
class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
"""The MagenticOneOrchestrator manages a group chat with ledger based orchestration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
model_client: ChatCompletionClient,
|
||||
max_stalls: int,
|
||||
final_answer_prompt: str,
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
emit_team_events: bool,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
self._final_answer_prompt = final_answer_prompt
|
||||
self._max_json_retries = 10
|
||||
self._task = ""
|
||||
self._facts = ""
|
||||
self._plan = ""
|
||||
self._n_rounds = 0
|
||||
self._n_stalls = 0
|
||||
|
||||
# Produce a team description. Each agent sould appear on a single line.
|
||||
self._team_description = ""
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
self._team_description = self._team_description.strip()
|
||||
|
||||
def _get_task_ledger_facts_prompt(self, task: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT.format(task=task)
|
||||
|
||||
def _get_task_ledger_plan_prompt(self, team: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT.format(team=team)
|
||||
|
||||
def _get_task_ledger_full_prompt(self, task: str, team: str, facts: str, plan: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT.format(task=task, team=team, facts=facts, plan=plan)
|
||||
|
||||
def _get_progress_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
|
||||
return ORCHESTRATOR_PROGRESS_LEDGER_PROMPT.format(task=task, team=team, names=", ".join(names))
|
||||
|
||||
def _get_task_ledger_facts_update_prompt(self, task: str, facts: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT.format(task=task, facts=facts)
|
||||
|
||||
def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)
|
||||
|
||||
def _get_final_answer_prompt(self, task: str) -> str:
|
||||
if self._final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT:
|
||||
return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
|
||||
else:
|
||||
return self._final_answer_prompt
|
||||
|
||||
async def _log_message(self, log_message: str) -> None:
|
||||
trace_logger.debug(log_message)
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: # type: ignore
|
||||
"""Handle the start of a task."""
|
||||
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(content="The group chat has already terminated.", source=self._name)
|
||||
# Signal termination.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
assert message is not None and message.messages is not None
|
||||
|
||||
# Validate the group state given all the messages.
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
# Log the message to the output topic.
|
||||
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
# Log the message to the output queue.
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Outer Loop for first time
|
||||
# Create the initial task ledger
|
||||
#################################
|
||||
# Combine all message contents for task
|
||||
self._task = " ".join([msg.to_model_text() for msg in message.messages])
|
||||
planning_conversation: List[LLMMessage] = []
|
||||
|
||||
# 1. GATHER FACTS
|
||||
# create a closed book task and generate a response and update the chat history
|
||||
planning_conversation.append(
|
||||
UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
|
||||
)
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._facts = response.content
|
||||
planning_conversation.append(AssistantMessage(content=self._facts, source=self._name))
|
||||
|
||||
# 2. CREATE A PLAN
|
||||
## plan based on available information
|
||||
planning_conversation.append(
|
||||
UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
|
||||
)
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._plan = response.content
|
||||
|
||||
# Kick things off
|
||||
self._n_stalls = 0
|
||||
await self._reenter_outer_loop(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_agent_response( # type: ignore
|
||||
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
|
||||
) -> None: # type: ignore
|
||||
try:
|
||||
if not isinstance(message, GroupChatAgentResponse):
|
||||
raise RuntimeError("MagenticOneOrchestrator does not support GroupChatTeamResponse messages.")
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.response.inner_messages is not None:
|
||||
for inner_message in message.response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
await self.update_message_thread([message.response.chat_message])
|
||||
delta.append(message.response.chat_message)
|
||||
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination.
|
||||
await self._signal_termination(stop_message)
|
||||
return
|
||||
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = MagenticOneOrchestratorState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
task=self._task,
|
||||
facts=self._facts,
|
||||
plan=self._plan,
|
||||
n_rounds=self._n_rounds,
|
||||
n_stalls=self._n_stalls,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in orchestrator_state.message_thread]
|
||||
self._current_turn = orchestrator_state.current_turn
|
||||
self._task = orchestrator_state.task
|
||||
self._facts = orchestrator_state.facts
|
||||
self._plan = orchestrator_state.plan
|
||||
self._n_rounds = orchestrator_state.n_rounds
|
||||
self._n_stalls = orchestrator_state.n_stalls
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
||||
return [""]
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the group chat manager."""
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._n_rounds = 0
|
||||
self._n_stalls = 0
|
||||
self._task = ""
|
||||
self._facts = ""
|
||||
self._plan = ""
|
||||
|
||||
async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Re-enter Outer loop of the orchestrator after creating task ledger."""
|
||||
# Reset the agents
|
||||
for participant_topic_type in self._participant_name_to_topic_type.values():
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self.id.key),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Reset partially the group chat manager
|
||||
self._message_thread.clear()
|
||||
|
||||
# Prepare the ledger
|
||||
ledger_message = TextMessage(
|
||||
content=self._get_task_ledger_full_prompt(self._task, self._team_description, self._facts, self._plan),
|
||||
source=self._name,
|
||||
)
|
||||
|
||||
# Save my copy
|
||||
await self.update_message_thread([ledger_message])
|
||||
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=ledger_message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(ledger_message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=Response(chat_message=ledger_message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
||||
# Restart the inner loop
|
||||
await self._orchestrate_step(cancellation_token=cancellation_token)
|
||||
|
||||
async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Implements the inner loop of the orchestrator and selects next speaker."""
|
||||
# Check if we reached the maximum number of rounds
|
||||
if self._max_turns is not None and self._n_rounds > self._max_turns:
|
||||
await self._prepare_final_answer("Max rounds reached.", cancellation_token)
|
||||
return
|
||||
self._n_rounds += 1
|
||||
|
||||
# Update the progress ledger
|
||||
context = self._thread_to_context()
|
||||
|
||||
progress_ledger_prompt = self._get_progress_ledger_prompt(
|
||||
self._task, self._team_description, self._participant_names
|
||||
)
|
||||
context.append(UserMessage(content=progress_ledger_prompt, source=self._name))
|
||||
progress_ledger: Dict[str, Any] = {}
|
||||
assert self._max_json_retries > 0
|
||||
key_error: bool = False
|
||||
for _ in range(self._max_json_retries):
|
||||
if self._model_client.model_info.get("structured_output", False):
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), json_output=LedgerEntry
|
||||
)
|
||||
elif self._model_client.model_info.get("json_output", False):
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token, json_output=True
|
||||
)
|
||||
else:
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
ledger_str = response.content
|
||||
try:
|
||||
assert isinstance(ledger_str, str)
|
||||
output_json = extract_json_from_str(ledger_str)
|
||||
if len(output_json) != 1:
|
||||
raise ValueError(
|
||||
f"Progress ledger should contain a single JSON object, but found: {len(progress_ledger)}"
|
||||
)
|
||||
progress_ledger = output_json[0]
|
||||
|
||||
# If the team consists of a single agent, deterministically set the next speaker
|
||||
if len(self._participant_names) == 1:
|
||||
progress_ledger["next_speaker"] = {
|
||||
"reason": "The team consists of only one agent.",
|
||||
"answer": self._participant_names[0],
|
||||
}
|
||||
|
||||
# Validate the structure
|
||||
required_keys = [
|
||||
"is_request_satisfied",
|
||||
"is_progress_being_made",
|
||||
"is_in_loop",
|
||||
"instruction_or_question",
|
||||
"next_speaker",
|
||||
]
|
||||
|
||||
key_error = False
|
||||
for key in required_keys:
|
||||
if (
|
||||
key not in progress_ledger
|
||||
or not isinstance(progress_ledger[key], dict)
|
||||
or "answer" not in progress_ledger[key]
|
||||
or "reason" not in progress_ledger[key]
|
||||
):
|
||||
key_error = True
|
||||
break
|
||||
|
||||
# Validate the next speaker if the task is not yet complete
|
||||
if (
|
||||
not progress_ledger["is_request_satisfied"]["answer"]
|
||||
and progress_ledger["next_speaker"]["answer"] not in self._participant_names
|
||||
):
|
||||
key_error = True
|
||||
break
|
||||
|
||||
if not key_error:
|
||||
break
|
||||
await self._log_message(f"Failed to parse ledger information, retrying: {ledger_str}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
key_error = True
|
||||
await self._log_message("Invalid ledger format encountered, retrying...")
|
||||
continue
|
||||
if key_error:
|
||||
raise ValueError("Failed to parse ledger information after multiple retries.")
|
||||
await self._log_message(f"Progress Ledger: {progress_ledger}")
|
||||
|
||||
# Check for task completion
|
||||
if progress_ledger["is_request_satisfied"]["answer"]:
|
||||
await self._log_message("Task completed, preparing final answer...")
|
||||
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token)
|
||||
return
|
||||
|
||||
# Check for stalling
|
||||
if not progress_ledger["is_progress_being_made"]["answer"]:
|
||||
self._n_stalls += 1
|
||||
elif progress_ledger["is_in_loop"]["answer"]:
|
||||
self._n_stalls += 1
|
||||
else:
|
||||
self._n_stalls = max(0, self._n_stalls - 1)
|
||||
|
||||
# Too much stalling
|
||||
if self._n_stalls >= self._max_stalls:
|
||||
await self._log_message("Stall count exceeded, re-planning with the outer loop...")
|
||||
await self._update_task_ledger(cancellation_token)
|
||||
await self._reenter_outer_loop(cancellation_token)
|
||||
return
|
||||
|
||||
# Broadcast the next step
|
||||
message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
|
||||
await self.update_message_thread([message]) # My copy
|
||||
|
||||
await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast it
|
||||
await self.publish_message( # Broadcast
|
||||
GroupChatAgentResponse(response=Response(chat_message=message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Request that the step be completed
|
||||
next_speaker = progress_ledger["next_speaker"]["answer"]
|
||||
# Check if the next speaker is valid
|
||||
if next_speaker not in self._participant_name_to_topic_type:
|
||||
raise ValueError(
|
||||
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_names}"
|
||||
)
|
||||
participant_topic_type = self._participant_name_to_topic_type[next_speaker]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=participant_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Send the message to the next speaker
|
||||
if self._emit_team_events:
|
||||
select_msg = SelectSpeakerEvent(content=[next_speaker], source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Update the task ledger (outer loop) with the latest facts and plan."""
|
||||
context = self._thread_to_context()
|
||||
|
||||
# Update the facts
|
||||
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
|
||||
context.append(UserMessage(content=update_facts_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._facts = response.content
|
||||
context.append(AssistantMessage(content=self._facts, source=self._name))
|
||||
|
||||
# Update the plan
|
||||
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
|
||||
context.append(UserMessage(content=update_plan_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._plan = response.content
|
||||
|
||||
async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None:
|
||||
"""Prepare the final answer for the task."""
|
||||
context = self._thread_to_context()
|
||||
|
||||
# Get the final answer
|
||||
final_answer_prompt = self._get_final_answer_prompt(self._task)
|
||||
context.append(UserMessage(content=final_answer_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
message = TextMessage(content=response.content, source=self._name)
|
||||
|
||||
await self.update_message_thread([message]) # My copy
|
||||
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=Response(chat_message=message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination
|
||||
await self._signal_termination(StopMessage(content=reason, source=self._name))
|
||||
|
||||
def _thread_to_context(self) -> List[LLMMessage]:
|
||||
"""Convert the message thread to a context for the model."""
|
||||
context: List[LLMMessage] = []
|
||||
for m in self._message_thread:
|
||||
if isinstance(m, ToolCallRequestEvent | ToolCallExecutionEvent):
|
||||
# Ignore tool call messages.
|
||||
continue
|
||||
elif isinstance(m, StopMessage | HandoffMessage):
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
elif m.source == self._name:
|
||||
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
assert isinstance(m, (TextMessage, MultiModalMessage, ToolCallSummaryMessage))
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
return context
|
||||
|
||||
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
|
||||
if self._model_client.model_info["vision"]:
|
||||
return messages
|
||||
else:
|
||||
return remove_images(messages)
|
||||
@@ -0,0 +1,149 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
ORCHESTRATOR_SYSTEM_MESSAGE = ""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT = """Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
|
||||
|
||||
Here is the request:
|
||||
|
||||
{task}
|
||||
|
||||
Here is the pre-survey:
|
||||
|
||||
1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
|
||||
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
|
||||
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
|
||||
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
|
||||
|
||||
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:
|
||||
|
||||
1. GIVEN OR VERIFIED FACTS
|
||||
2. FACTS TO LOOK UP
|
||||
3. FACTS TO DERIVE
|
||||
4. EDUCATED GUESSES
|
||||
|
||||
DO NOT include any other headings or sections in your response. DO NOT list next steps or plans until asked to do so.
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT = """Fantastic. To address this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task."""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT = """
|
||||
We are working to address the following user request:
|
||||
|
||||
{task}
|
||||
|
||||
|
||||
To answer this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
|
||||
Here is an initial fact sheet to consider:
|
||||
|
||||
{facts}
|
||||
|
||||
|
||||
Here is the plan to follow as best as possible:
|
||||
|
||||
{plan}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT = """
|
||||
Recall we are working on the following request:
|
||||
|
||||
{task}
|
||||
|
||||
And we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
To make progress on the request, please answer the following questions, including necessary reasoning:
|
||||
|
||||
- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY and FULLY addressed)
|
||||
- Are we in a loop where we are repeating the same requests and / or getting the same responses as before? Loops can span multiple turns, and can include repeated actions like scrolling up or down more than a handful of times.
|
||||
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a loop or if there is evidence of significant barriers to success such as the inability to read from a required file)
|
||||
- Who should speak next? (select from: {names})
|
||||
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)
|
||||
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"is_request_satisfied": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_in_loop": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_progress_being_made": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"next_speaker": {{
|
||||
"reason": string,
|
||||
"answer": string (select from: {names})
|
||||
}},
|
||||
"instruction_or_question": {{
|
||||
"reason": string,
|
||||
"answer": string
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class LedgerEntryBooleanAnswer(BaseModel):
|
||||
reason: str
|
||||
answer: bool
|
||||
|
||||
|
||||
class LedgerEntryStringAnswer(BaseModel):
|
||||
reason: str
|
||||
answer: str
|
||||
|
||||
|
||||
class LedgerEntry(BaseModel):
|
||||
is_request_satisfied: LedgerEntryBooleanAnswer
|
||||
is_in_loop: LedgerEntryBooleanAnswer
|
||||
is_progress_being_made: LedgerEntryBooleanAnswer
|
||||
next_speaker: LedgerEntryStringAnswer
|
||||
instruction_or_question: LedgerEntryStringAnswer
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT = """As a reminder, we are working to solve the following task:
|
||||
|
||||
{task}
|
||||
|
||||
It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned that may be helpful. Example edits can include (but are not limited to) adding new guesses, moving educated guesses to verified facts if appropriate, etc. Updates may be made to any section of the fact sheet, and more than one section of the fact sheet can be edited. This is an especially good time to update educated guesses, so please at least add or update one educated guess or hunch, and explain your reasoning.
|
||||
|
||||
Here is the old fact sheet:
|
||||
|
||||
{facts}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
|
||||
|
||||
{team}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT = """
|
||||
We are working on the following task:
|
||||
{task}
|
||||
|
||||
We have completed the task.
|
||||
|
||||
The above messages contain the conversation that took place to complete the task.
|
||||
|
||||
Based on the information gathered, provide the final answer to the original request.
|
||||
The answer should be phrased as if you were speaking to the user.
|
||||
"""
|
||||
@@ -0,0 +1,328 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...base import ChatAgent, Team, TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from ...state import RoundRobinManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker in a round-robin fashion."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = RoundRobinManagerState(
|
||||
message_thread=[message.dump() for message in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
next_speaker_index=self._next_speaker_index,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
round_robin_state = RoundRobinManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread]
|
||||
self._current_turn = round_robin_state.current_turn
|
||||
self._next_speaker_index = round_robin_state.next_speaker_index
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants in a round-robin fashion.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
|
||||
current_speaker = self._participant_names[current_speaker_index]
|
||||
return current_speaker
|
||||
|
||||
|
||||
class RoundRobinGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration RoundRobinGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
||||
"""A team that runs a group chat with participants taking turns in a round-robin fashion
|
||||
to publish a message to all.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a single participant is in the team, the participant will be the only speaker.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent | Team]): The participants in the group chat.
|
||||
name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: If no participants are provided or if participant names are not unique.
|
||||
|
||||
Examples:
|
||||
|
||||
A team with one participant with tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def get_weather(location: str) -> str:
|
||||
return f"The weather in {location} is sunny."
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = RoundRobinGroupChat([assistant], termination_condition=termination)
|
||||
await Console(team.run_stream(task="What's the weather in New York?"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with multiple participants:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
await Console(team.run_stream(task="Tell me some jokes."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team of user proxy and a nested team of writer and reviewer agents:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import UserProxyAgent, AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination, MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
|
||||
writer = AssistantAgent(
|
||||
"writer", model_client=model_client, system_message="You are a writer.", model_client_stream=True
|
||||
)
|
||||
|
||||
reviewer = AssistantAgent(
|
||||
"reviewer",
|
||||
model_client=model_client,
|
||||
system_message="Provide feedback to the input and suggest improvements.",
|
||||
model_client_stream=True,
|
||||
)
|
||||
|
||||
# NOTE: you can skip input by pressing Enter.
|
||||
user_proxy = UserProxyAgent("user_proxy")
|
||||
|
||||
# Maximum 1 round of review and revision.
|
||||
inner_termination = MaxMessageTermination(max_messages=4)
|
||||
|
||||
# The outter-loop termination condition that will terminate the team when the user types "exit".
|
||||
outter_termination = TextMentionTermination("exit", sources=["user_proxy"])
|
||||
|
||||
team = RoundRobinGroupChat(
|
||||
[
|
||||
# For each turn, the writer writes a summary and the reviewer reviews it.
|
||||
RoundRobinGroupChat([writer, reviewer], termination_condition=inner_termination),
|
||||
# The user proxy gets user input once the writer and reviewer have finished their actions.
|
||||
user_proxy,
|
||||
],
|
||||
termination_condition=outter_termination,
|
||||
)
|
||||
# Start the team and wait for it to terminate.
|
||||
await Console(team.run_stream(task="Write a short essay about the impact of AI on society."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = RoundRobinGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.RoundRobinGroupChat"
|
||||
|
||||
DEFAULT_NAME = "RoundRobinGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent | Team],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=participants,
|
||||
group_chat_manager_name="RoundRobinGroupChatManager",
|
||||
group_chat_manager_class=RoundRobinGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], RoundRobinGroupChatManager]:
|
||||
def _factory() -> RoundRobinGroupChatManager:
|
||||
return RoundRobinGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> RoundRobinGroupChatConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return RoundRobinGroupChatConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
|
||||
participants: List[ChatAgent | Team] = []
|
||||
for participant in config.participants:
|
||||
if participant.component_type == Team.component_type:
|
||||
participants.append(Team.load_component(participant))
|
||||
else:
|
||||
participants.append(ChatAgent.load_component(participant))
|
||||
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
@@ -0,0 +1,730 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from agentdhal_core import AgentRuntime, CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ... import TRACE_LOGGER_NAME
|
||||
from ...base import ChatAgent, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
SelectorEvent,
|
||||
)
|
||||
from ...state import SelectorManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
|
||||
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
|
||||
SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]
|
||||
|
||||
SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
|
||||
AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]]
|
||||
CandidateFuncType = Union[SyncCandidateFunc | AsyncCandidateFunc]
|
||||
|
||||
|
||||
class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker using a ChatCompletion
|
||||
model and a custom selector function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
model_client: ChatCompletionClient,
|
||||
selector_prompt: str,
|
||||
allow_repeated_speaker: bool,
|
||||
selector_func: Optional[SelectorFuncType],
|
||||
max_selector_attempts: int,
|
||||
candidate_func: Optional[CandidateFuncType],
|
||||
emit_team_events: bool,
|
||||
model_context: ChatCompletionContext | None,
|
||||
model_client_streaming: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._selector_prompt = selector_prompt
|
||||
self._previous_speaker: str | None = None
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._is_selector_func_async = iscoroutinefunction(self._selector_func)
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
|
||||
self._model_client_streaming = model_client_streaming
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
self._cancellation_token = CancellationToken()
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
await self._model_context.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._previous_speaker = None
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = SelectorManagerState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
previous_speaker=self._previous_speaker,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
selector_state = SelectorManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
|
||||
await self._add_messages_to_context(
|
||||
self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
|
||||
)
|
||||
self._current_turn = selector_state.current_turn
|
||||
self._previous_speaker = selector_state.previous_speaker
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
self._message_thread.extend(messages)
|
||||
base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
|
||||
await self._add_messages_to_context(self._model_context, base_chat_messages)
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||
with the selector function as override if it returns a speaker name.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker name.
|
||||
|
||||
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
|
||||
"""
|
||||
# Use the selector function if provided.
|
||||
if self._selector_func is not None:
|
||||
if self._is_selector_func_async:
|
||||
async_selector_func = cast(AsyncSelectorFunc, self._selector_func)
|
||||
speaker = await async_selector_func(thread)
|
||||
else:
|
||||
sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
|
||||
speaker = sync_selector_func(thread)
|
||||
if speaker is not None:
|
||||
if speaker not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"Selector function returned an invalid speaker name: {speaker}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
# Skip the model based selection.
|
||||
return [speaker]
|
||||
|
||||
# Use the candidate function to filter participants if provided
|
||||
if self._candidate_func is not None:
|
||||
if self._is_candidate_func_async:
|
||||
async_candidate_func = cast(AsyncCandidateFunc, self._candidate_func)
|
||||
participants = await async_candidate_func(thread)
|
||||
else:
|
||||
sync_candidate_func = cast(SyncCandidateFunc, self._candidate_func)
|
||||
participants = sync_candidate_func(thread)
|
||||
if not participants:
|
||||
raise ValueError("Candidate function must return a non-empty list of participant names.")
|
||||
if not all(p in self._participant_names for p in participants):
|
||||
raise ValueError(
|
||||
f"Candidate function returned invalid participant names: {participants}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
else:
|
||||
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
|
||||
if self._previous_speaker is not None and not self._allow_repeated_speaker:
|
||||
participants = [p for p in self._participant_names if p != self._previous_speaker]
|
||||
else:
|
||||
participants = list(self._participant_names)
|
||||
|
||||
assert len(participants) > 0
|
||||
|
||||
# Construct agent roles.
|
||||
# Each agent sould appear on a single line.
|
||||
roles = ""
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
roles = roles.strip()
|
||||
|
||||
# Select the next speaker.
|
||||
if len(participants) > 1:
|
||||
agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
|
||||
else:
|
||||
agent_name = participants[0]
|
||||
self._previous_speaker = agent_name
|
||||
trace_logger.debug(f"Selected speaker: {agent_name}")
|
||||
return [agent_name]
|
||||
|
||||
def construct_message_history(self, message_history: List[LLMMessage]) -> str:
|
||||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for msg in message_history:
|
||||
if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
|
||||
message = f"{msg.source}: {msg.content}"
|
||||
history_messages.append(
|
||||
message.rstrip() + "\n\n"
|
||||
) # Create some consistency for how messages are separated in the transcript
|
||||
|
||||
history: str = "\n".join(history_messages)
|
||||
return history
|
||||
|
||||
async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
|
||||
model_context_messages = await self._model_context.get_messages()
|
||||
model_context_history = self.construct_message_history(model_context_messages)
|
||||
|
||||
select_speaker_prompt = self._selector_prompt.format(
|
||||
roles=roles, participants=str(participants), history=model_context_history
|
||||
)
|
||||
|
||||
select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
|
||||
if ModelFamily.is_openai(self._model_client.model_info["family"]):
|
||||
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
|
||||
else:
|
||||
# Many other models need a UserMessage to respond to
|
||||
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")]
|
||||
|
||||
num_attempts = 0
|
||||
while num_attempts < max_attempts:
|
||||
num_attempts += 1
|
||||
if self._model_client_streaming:
|
||||
chunk: CreateResult | str = ""
|
||||
async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
|
||||
chunk = _chunk
|
||||
if self._emit_team_events:
|
||||
if isinstance(chunk, str):
|
||||
await self._output_message_queue.put(
|
||||
ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
|
||||
)
|
||||
else:
|
||||
assert isinstance(chunk, CreateResult)
|
||||
assert isinstance(chunk.content, str)
|
||||
await self._output_message_queue.put(
|
||||
SelectorEvent(content=chunk.content, source=self._name)
|
||||
)
|
||||
# The last chunk must be CreateResult.
|
||||
assert isinstance(chunk, CreateResult)
|
||||
response = chunk
|
||||
else:
|
||||
response = await self._model_client.create(messages=select_speaker_messages)
|
||||
assert isinstance(response.content, str)
|
||||
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
|
||||
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
|
||||
# This is because the model may still select the previous speaker, and we want to catch that.
|
||||
mentions = self._mentioned_agents(response.content, self._participant_names)
|
||||
if len(mentions) == 0:
|
||||
trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
|
||||
feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
elif len(mentions) > 1:
|
||||
trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})")
|
||||
feedback = (
|
||||
f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}."
|
||||
)
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
else:
|
||||
agent_name = list(mentions.keys())[0]
|
||||
if (
|
||||
not self._allow_repeated_speaker
|
||||
and self._previous_speaker is not None
|
||||
and agent_name == self._previous_speaker
|
||||
):
|
||||
trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
|
||||
feedback = (
|
||||
f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
|
||||
)
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
else:
|
||||
# Valid selection
|
||||
trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
|
||||
return agent_name
|
||||
|
||||
if self._previous_speaker is not None:
|
||||
trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
|
||||
return self._previous_speaker
|
||||
trace_logger.warning(
|
||||
f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
|
||||
)
|
||||
return participants[0]
|
||||
|
||||
def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
|
||||
"""Counts the number of times each agent is mentioned in the provided message content.
|
||||
Agent names will match under any of the following conditions (all case-sensitive):
|
||||
- Exact name match
|
||||
- If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
|
||||
- If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
|
||||
|
||||
Args:
|
||||
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
|
||||
agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content.
|
||||
|
||||
Returns:
|
||||
Dict: a counter for mentioned agents.
|
||||
"""
|
||||
mentions: Dict[str, int] = dict()
|
||||
for name in agent_names:
|
||||
# Finds agent mentions, taking word boundaries into account,
|
||||
# accommodates escaping underscores and underscores as spaces
|
||||
regex = (
|
||||
r"(?<=\W)("
|
||||
+ re.escape(name)
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", " "))
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", r"\_"))
|
||||
+ r")(?=\W)"
|
||||
)
|
||||
# Pad the message to help with matching
|
||||
count = len(re.findall(regex, f" {message_content} "))
|
||||
if count > 0:
|
||||
mentions[name] = count
|
||||
return mentions
|
||||
|
||||
|
||||
class SelectorGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration for SelectorGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
model_client: ComponentModel
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
selector_prompt: str
|
||||
allow_repeated_speaker: bool
|
||||
# selector_func: ComponentModel | None
|
||||
max_selector_attempts: int = 3
|
||||
emit_team_events: bool = False
|
||||
model_client_streaming: bool = False
|
||||
model_context: ComponentModel | None = None
|
||||
|
||||
|
||||
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||
"""A group chat team that have participants takes turn to publish a message
|
||||
to all, using a ChatCompletion model to select the next speaker after each message.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent | Team]): The participants in the group chat,
|
||||
must have unique names and at least two participants.
|
||||
model_client (ChatCompletionClient): The ChatCompletion model client used
|
||||
to select the next speaker.
|
||||
name (str | None, optional): The name of the group chat, using
|
||||
:attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must
|
||||
be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using
|
||||
:attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
selector_prompt (str, optional): The prompt template to use for selecting the next speaker.
|
||||
Available fields: '{roles}', '{participants}', and '{history}'.
|
||||
`{participants}` is the names of candidates for selection. The format is `["<name1>", "<name2>", ...]`.
|
||||
`{roles}` is a newline-separated list of names and descriptions of the candidate agents. The format for each line is: `"<name> : <description>"`.
|
||||
`{history}` is the conversation history formatted as a double newline separated of names and message content. The format for each message is: `"<name> : <message content>"`.
|
||||
allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn.
|
||||
Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens.
|
||||
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
|
||||
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
|
||||
otherwise the first participant will be used.
|
||||
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
|
||||
function that takes the conversation history and returns the name of the next speaker.
|
||||
If provided, this function will be used to override the model to select the next speaker.
|
||||
If the function returns None, the model will be used to select the next speaker.
|
||||
NOTE: `selector_func` is not serializable and will be ignored during serialization and deserialization process.
|
||||
candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
|
||||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
|
||||
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
|
||||
:class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
||||
|
||||
Examples:
|
||||
|
||||
A team with multiple participants:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = SelectorGroupChat(
|
||||
[travel_advisor, hotel_agent, flight_agent],
|
||||
model_client=model_client,
|
||||
termination_condition=termination,
|
||||
)
|
||||
await Console(team.run_stream(task="Book a 3-day trip to new york."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with a custom selector function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from typing import Sequence
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
def check_calculation(x: int, y: int, answer: int) -> str:
|
||||
if x + y == answer:
|
||||
return "Correct!"
|
||||
else:
|
||||
return "Incorrect!"
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Agent1",
|
||||
model_client,
|
||||
description="For calculation",
|
||||
system_message="Calculate the sum of two numbers",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Agent2",
|
||||
model_client,
|
||||
tools=[check_calculation],
|
||||
description="For checking calculation",
|
||||
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
|
||||
)
|
||||
|
||||
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
|
||||
if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
|
||||
return "Agent1"
|
||||
if messages[-1].source == "Agent1":
|
||||
return "Agent2"
|
||||
return None
|
||||
|
||||
termination = TextMentionTermination("Correct!")
|
||||
team = SelectorGroupChat(
|
||||
[agent1, agent2],
|
||||
model_client=model_client,
|
||||
selector_func=selector_func,
|
||||
termination_condition=termination,
|
||||
)
|
||||
|
||||
await Console(team.run_stream(task="What is 1 + 1?"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with custom model context:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_core.model_context import BufferedChatCompletionContext
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
model_context = BufferedChatCompletionContext(buffer_size=5)
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = SelectorGroupChat(
|
||||
[travel_advisor, hotel_agent, flight_agent],
|
||||
model_client=model_client,
|
||||
termination_condition=termination,
|
||||
model_context=model_context,
|
||||
)
|
||||
await Console(team.run_stream(task="Book a 3-day trip to new york."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SelectorGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.SelectorGroupChat"
|
||||
|
||||
DEFAULT_NAME = "SelectorGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent | Team],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
|
||||
{history}
|
||||
|
||||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
""",
|
||||
allow_repeated_speaker: bool = False,
|
||||
max_selector_attempts: int = 3,
|
||||
selector_func: Optional[SelectorFuncType] = None,
|
||||
candidate_func: Optional[CandidateFuncType] = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
model_client_streaming: bool = False,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=participants,
|
||||
group_chat_manager_name="SelectorGroupChatManager",
|
||||
group_chat_manager_class=SelectorGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# Validate the participants.
|
||||
if len(participants) < 2:
|
||||
raise ValueError("At least two participants are required for SelectorGroupChat.")
|
||||
self._selector_prompt = selector_prompt
|
||||
self._model_client = model_client
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
self._model_client_streaming = model_client_streaming
|
||||
self._model_context = model_context
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], BaseGroupChatManager]:
|
||||
return lambda: SelectorGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._model_client,
|
||||
self._selector_prompt,
|
||||
self._allow_repeated_speaker,
|
||||
self._selector_func,
|
||||
self._max_selector_attempts,
|
||||
self._candidate_func,
|
||||
self._emit_team_events,
|
||||
self._model_context,
|
||||
self._model_client_streaming,
|
||||
)
|
||||
|
||||
def _to_config(self) -> SelectorGroupChatConfig:
|
||||
return SelectorGroupChatConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=[participant.dump_component() for participant in self._participants],
|
||||
model_client=self._model_client.dump_component(),
|
||||
termination_condition=self._termination_condition.dump_component() if self._termination_condition else None,
|
||||
max_turns=self._max_turns,
|
||||
selector_prompt=self._selector_prompt,
|
||||
allow_repeated_speaker=self._allow_repeated_speaker,
|
||||
max_selector_attempts=self._max_selector_attempts,
|
||||
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
|
||||
emit_team_events=self._emit_team_events,
|
||||
model_client_streaming=self._model_client_streaming,
|
||||
model_context=self._model_context.dump_component() if self._model_context else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
|
||||
participants: List[ChatAgent | Team] = []
|
||||
for participant in config.participants:
|
||||
if participant.component_type == ChatAgent.component_type:
|
||||
participants.append(ChatAgent.load_component(participant))
|
||||
elif participant.component_type == Team.component_type:
|
||||
participants.append(Team.load_component(participant))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid participant component type: {participant.component_type}. " "Expected ChatAgent or Team."
|
||||
)
|
||||
return cls(
|
||||
participants=participants,
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=TerminationCondition.load_component(config.termination_condition)
|
||||
if config.termination_condition
|
||||
else None,
|
||||
max_turns=config.max_turns,
|
||||
selector_prompt=config.selector_prompt,
|
||||
allow_repeated_speaker=config.allow_repeated_speaker,
|
||||
max_selector_attempts=config.max_selector_attempts,
|
||||
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
|
||||
# if config.selector_func
|
||||
# else None,
|
||||
emit_team_events=config.emit_team_events,
|
||||
model_client_streaming=config.model_client_streaming,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
from typing import Any, Sequence
|
||||
|
||||
from agentdhal_core import MessageContext, RoutedAgent
|
||||
|
||||
|
||||
class FIFOLock:
|
||||
"""A lock that ensures coroutines acquire the lock in the order they request it."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queue = asyncio.Queue[asyncio.Event]()
|
||||
self._locked = False
|
||||
|
||||
async def acquire(self) -> None:
|
||||
# If the lock is not held by any coroutine, set the lock to be held
|
||||
# by the current coroutine.
|
||||
if not self._locked:
|
||||
self._locked = True
|
||||
return
|
||||
|
||||
# If the lock is held by another coroutine, create an event and put it
|
||||
# in the queue. Wait for the event to be set.
|
||||
event = asyncio.Event()
|
||||
await self._queue.put(event)
|
||||
await event.wait()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._queue.empty():
|
||||
# If there are events in the queue, get the next event and set it.
|
||||
next_event = self._queue.get_nowait()
|
||||
next_event.set()
|
||||
else:
|
||||
# If there are no events in the queue, release the lock.
|
||||
self._locked = False
|
||||
|
||||
|
||||
class SequentialRoutedAgent(RoutedAgent):
|
||||
"""A subclass of :class:`agentdhal_core.RoutedAgent` that ensures
|
||||
that messages of certain types are processed sequentially
|
||||
using a FIFO lock.
|
||||
|
||||
This is useful for agents that need to maintain a strict order of
|
||||
processing messages, such as in a group chat scenario.
|
||||
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
description (str): The description of the agent.
|
||||
sequential_message_types (Sequence[Type[Any]]): A sequence of message types that should be
|
||||
processed sequentially. If a message of one of these types is received,
|
||||
the agent will acquire a FIFO lock to ensure that it is processed
|
||||
before any later messages that are also one of these types.
|
||||
"""
|
||||
|
||||
def __init__(self, description: str, sequential_message_types: Sequence[type[Any]]) -> None:
|
||||
super().__init__(description=description)
|
||||
self._fifo_lock = FIFOLock()
|
||||
self._sequential_message_types = sequential_message_types
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
if any(isinstance(message, sequential_type) for sequential_type in self._sequential_message_types):
|
||||
# Acquire the FIFO lock to ensure that this message is processed
|
||||
# in the order it was received.
|
||||
await self._fifo_lock.acquire()
|
||||
try:
|
||||
return await super().on_message_impl(message, ctx)
|
||||
finally:
|
||||
# Release the FIFO lock to allow the next message to be processed.
|
||||
self._fifo_lock.release()
|
||||
# If the message is not of a sequential type, process it normally.
|
||||
return await super().on_message_impl(message, ctx)
|
||||
@@ -0,0 +1,321 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
|
||||
from ...state import SwarmManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker based on handoff message only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the start messages for the group chat."""
|
||||
# Check if any of the start messages is a handoff message.
|
||||
if messages:
|
||||
for message in messages:
|
||||
if isinstance(message, HandoffMessage):
|
||||
if message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The target {message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if there is a handoff message in the thread that is not targeting a valid participant.
|
||||
for existing_message in reversed(self._message_thread):
|
||||
if isinstance(existing_message, HandoffMessage):
|
||||
if existing_message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new task make sure to include in your task "
|
||||
"a HandoffMessage with a valid participant as the target. For example, if you are "
|
||||
"resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
|
||||
"with a valid participant as the target."
|
||||
)
|
||||
# The latest handoff message should always target a valid participant.
|
||||
# Do not look past the latest handoff message.
|
||||
return
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
Looks for the last handoff message in the thread to determine the next speaker.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
if len(thread) == 0:
|
||||
return [self._current_speaker]
|
||||
for message in reversed(thread):
|
||||
if isinstance(message, HandoffMessage):
|
||||
self._current_speaker = message.target
|
||||
# The latest handoff message should always target a valid participant.
|
||||
assert self._current_speaker in self._participant_names
|
||||
return [self._current_speaker]
|
||||
return self._current_speaker
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = SwarmManagerState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
current_speaker=self._current_speaker,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
swarm_state = SwarmManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread]
|
||||
self._current_turn = swarm_state.current_turn
|
||||
self._current_speaker = swarm_state.current_speaker
|
||||
|
||||
|
||||
class SwarmConfig(BaseModel):
|
||||
"""The declarative configuration for Swarm."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
||||
"""A group chat team that selects the next speaker based on handoff message only.
|
||||
|
||||
The first participant in the list of participants is the initial speaker.
|
||||
The next speaker is selected based on the :class:`~agentdhal_agentchat.messages.HandoffMessage` message
|
||||
sent by the current speaker. If no handoff message is sent, the current speaker
|
||||
continues to be the speaker.
|
||||
|
||||
.. note::
|
||||
|
||||
Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and
|
||||
:class:`~agentdhal_agentchat.teams.SelectorGroupChat`, this group chat
|
||||
team does not support inner teams as participants.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
|
||||
name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Basic example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import Swarm
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["Bob"],
|
||||
system_message="You are Alice and you only answer questions about yourself.",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
|
||||
)
|
||||
|
||||
termination = MaxMessageTermination(3)
|
||||
team = Swarm([agent1, agent2], termination_condition=termination)
|
||||
|
||||
stream = team.run_stream(task="What is bob's birthday?")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Using the :class:`~agentdhal_agentchat.conditions.HandoffTermination` for human-in-the-loop handoff:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import Swarm
|
||||
from agentdhal_agentchat.conditions import HandoffTermination, MaxMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.messages import HandoffMessage
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["user"],
|
||||
system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.",
|
||||
)
|
||||
termination = HandoffTermination(target="user") | MaxMessageTermination(3)
|
||||
team = Swarm([agent], termination_condition=termination)
|
||||
|
||||
# Start the conversation.
|
||||
await Console(team.run_stream(task="What is bob's birthday?"))
|
||||
|
||||
# Resume with user feedback.
|
||||
await Console(
|
||||
team.run_stream(
|
||||
task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SwarmConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.Swarm"
|
||||
|
||||
DEFAULT_NAME = "Swarm"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=[participant for participant in participants],
|
||||
group_chat_manager_name="SwarmGroupChatManager",
|
||||
group_chat_manager_class=SwarmGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# The first participant must be able to produce handoff messages.
|
||||
first_participant = self._participants[0]
|
||||
assert isinstance(first_participant, ChatAgent)
|
||||
if HandoffMessage not in first_participant.produced_message_types:
|
||||
raise ValueError("The first participant must be able to produce a handoff messages.")
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], SwarmGroupChatManager]:
|
||||
def _factory() -> SwarmGroupChatManager:
|
||||
return SwarmGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> SwarmConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return SwarmConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SwarmConfig) -> "Swarm":
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
Reference in New Issue
Block a user