first commit
This commit is contained in:
85
agent_dhal/__init__.py
Normal file
85
agent_dhal/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AgentDhal - Complete AI Agent Framework for DarkHal 2.0
|
||||
|
||||
A comprehensive agent framework providing:
|
||||
- Multi-agent conversation capabilities
|
||||
- Agent orchestration and team management
|
||||
- Tool integration and function calling
|
||||
- Model context management
|
||||
- Memory and state management
|
||||
- Customizable agent behaviors
|
||||
|
||||
Legal Attribution:
|
||||
This software is based on Microsoft AutoGen (https://github.com/microsoft/autogen)
|
||||
Licensed under MIT License. AgentDhal is a derivative work with
|
||||
modifications and extensions for the DarkHal project.
|
||||
|
||||
Copyright (c) 2025 DarkHal Project
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "DarkHal Project (based on Microsoft AutoGen)"
|
||||
|
||||
# Import core AgentDhal components
|
||||
from .agentdhal_core import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentRuntime,
|
||||
SingleThreadedAgentRuntime,
|
||||
RoutedAgent,
|
||||
MessageContext,
|
||||
DefaultTopicId,
|
||||
message_handler,
|
||||
default_subscription,
|
||||
BaseAgent,
|
||||
AgentType,
|
||||
TopicId,
|
||||
Subscription
|
||||
)
|
||||
|
||||
# Import Dhal - our primary AI agent
|
||||
from .hal import Dhal, DhalConfig, create_dhal
|
||||
|
||||
# Import other AgentDhal components (available but not primary focus)
|
||||
try:
|
||||
from .agentdhal_agentchat import (
|
||||
AssistantAgent,
|
||||
UserProxyAgent,
|
||||
ChatAgent,
|
||||
Team
|
||||
)
|
||||
except ImportError:
|
||||
# Graceful fallback if agentchat modules have issues
|
||||
AssistantAgent = None
|
||||
UserProxyAgent = None
|
||||
ChatAgent = None
|
||||
Team = None
|
||||
|
||||
__all__ = [
|
||||
# Core framework
|
||||
"Agent",
|
||||
"AgentId",
|
||||
"AgentRuntime",
|
||||
"SingleThreadedAgentRuntime",
|
||||
"RoutedAgent",
|
||||
"MessageContext",
|
||||
"DefaultTopicId",
|
||||
"message_handler",
|
||||
"default_subscription",
|
||||
"BaseAgent",
|
||||
"AgentType",
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
|
||||
# Primary Hal Agent
|
||||
"Hal",
|
||||
"HalConfig",
|
||||
"create_hal",
|
||||
|
||||
# Additional Agent Components (if available)
|
||||
"AssistantAgent",
|
||||
"UserProxyAgent",
|
||||
"ChatAgent",
|
||||
"Team"
|
||||
]
|
||||
14
agent_dhal/agentdhal_agentchat/__init__.py
Normal file
14
agent_dhal/agentdhal_agentchat/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
This module provides the main entry point for the agentdhal_agentchat package.
|
||||
It includes logger names for trace and event logs, and retrieves the package version.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
TRACE_LOGGER_NAME = "agentdhal_agentchat"
|
||||
"""Logger name for trace logs."""
|
||||
|
||||
EVENT_LOGGER_NAME = "agentdhal_agentchat.events"
|
||||
"""Logger name for event logs."""
|
||||
|
||||
__version__ = importlib.metadata.version("agentdhal_agentchat")
|
||||
25
agent_dhal/agentdhal_agentchat/agents/__init__.py
Normal file
25
agent_dhal/agentdhal_agentchat/agents/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
This module initializes various pre-defined agents provided by the package.
|
||||
BaseChatAgent is the base class for all agents in AgentChat.
|
||||
"""
|
||||
|
||||
from ._assistant_agent import AssistantAgent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import ApprovalFuncType, ApprovalRequest, ApprovalResponse, CodeExecutorAgent
|
||||
from ._message_filter_agent import MessageFilterAgent, MessageFilterConfig, PerSourceFilter
|
||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||
from ._user_proxy_agent import UserProxyAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"AssistantAgent",
|
||||
"CodeExecutorAgent",
|
||||
"SocietyOfMindAgent",
|
||||
"UserProxyAgent",
|
||||
"MessageFilterAgent",
|
||||
"MessageFilterConfig",
|
||||
"PerSourceFilter",
|
||||
"ApprovalRequest",
|
||||
"ApprovalResponse",
|
||||
"ApprovalFuncType",
|
||||
]
|
||||
1699
agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py
Normal file
1699
agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
245
agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py
Normal file
245
agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, ComponentBase, trace_create_agent_span, trace_invoke_agent_span
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
from ..state import BaseState
|
||||
|
||||
|
||||
class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
||||
"""Base class for a chat agent.
|
||||
|
||||
This abstract class provides a base implementation for a :class:`ChatAgent`.
|
||||
To create a new chat agent, subclass this class and implement the
|
||||
:meth:`on_messages`, :meth:`on_reset`, and :attr:`produced_message_types`.
|
||||
If streaming is required, also implement the :meth:`on_messages_stream` method.
|
||||
|
||||
An agent is considered stateful and maintains its state between calls to
|
||||
the :meth:`on_messages` or :meth:`on_messages_stream` methods.
|
||||
The agent should store its state in the
|
||||
agent instance. The agent should also implement the :meth:`on_reset` method
|
||||
to reset the agent to its initialization state.
|
||||
|
||||
.. note::
|
||||
|
||||
The caller should only pass the new messages to the agent on each call
|
||||
to the :meth:`on_messages` or :meth:`on_messages_stream` method.
|
||||
Do not pass the entire conversation history to the agent on each call.
|
||||
This design principle must be followed when creating a new agent.
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
"""Initialize the agent with a name and description."""
|
||||
with trace_create_agent_span(
|
||||
agent_name=name,
|
||||
agent_description=description,
|
||||
):
|
||||
self._name = name
|
||||
if self._name.isidentifier() is False:
|
||||
raise ValueError("The agent name must be a valid Python identifier.")
|
||||
self._description = description
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the agent. This is used by team to uniquely identify
|
||||
the agent. It should be unique within the team."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""The description of the agent. This is used by team to
|
||||
make decisions about which agents to use. The description should
|
||||
describe the agent's capabilities and how to interact with it."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response.
|
||||
|
||||
.. note::
|
||||
|
||||
Agents are stateful and the messages passed to this method should
|
||||
be the new messages since the last call to this method. The agent
|
||||
should maintain its state between calls to this method. For example,
|
||||
if the agent needs to remember the previous messages to respond to
|
||||
the current message, it should store the previous messages in the
|
||||
agent state.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of messages and
|
||||
and the final item is the response. The base implementation in
|
||||
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
|
||||
the messages in the response.
|
||||
|
||||
.. note::
|
||||
|
||||
Agents are stateful and the messages passed to this method should
|
||||
be the new messages since the last call to this method. The agent
|
||||
should maintain its state between calls to this method. For example,
|
||||
if the agent needs to remember the previous messages to respond to
|
||||
the current message, it should store the previous messages in the
|
||||
agent state.
|
||||
|
||||
"""
|
||||
response = await self.on_messages(messages, cancellation_token)
|
||||
for inner_message in response.inner_messages or []:
|
||||
yield inner_message
|
||||
yield response
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self.name,
|
||||
agent_description=self.description,
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
if output_task_messages:
|
||||
output_messages.append(task)
|
||||
else:
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
# Task is a sequence of messages.
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
response = await self.on_messages(input_messages, cancellation_token)
|
||||
if response.inner_messages is not None:
|
||||
output_messages += response.inner_messages
|
||||
output_messages.append(response.chat_message)
|
||||
return TaskResult(messages=output_messages)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
and the final task result as the last item in the stream.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
"""
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self.name,
|
||||
agent_description=self.description,
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
if output_task_messages:
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
else:
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
async for message in self.on_messages_stream(input_messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
yield message.chat_message
|
||||
output_messages.append(message.chat_message)
|
||||
yield TaskResult(messages=output_messages)
|
||||
else:
|
||||
yield message
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
output_messages.append(message)
|
||||
|
||||
@abstractmethod
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Resets the agent to its initialization state."""
|
||||
...
|
||||
|
||||
async def on_pause(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is paused while running in its :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` method. This is a no-op by default in the
|
||||
:class:`BaseChatAgent` class. Subclasses can override this method to
|
||||
implement custom pause behavior."""
|
||||
pass
|
||||
|
||||
async def on_resume(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is resumed from a pause while running in
|
||||
its :meth:`on_messages` or :meth:`on_messages_stream` method.
|
||||
This is a no-op by default in the :class:`BaseChatAgent` class.
|
||||
Subclasses can override this method to implement custom resume behavior."""
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Export state. Default implementation for stateless agents."""
|
||||
return BaseState().model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state. Default implementation for stateless agents."""
|
||||
BaseState.model_validate(state)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release any resources held by the agent. This is a no-op by default in the
|
||||
:class:`BaseChatAgent` class. Subclasses can override this method to
|
||||
implement custom close behavior."""
|
||||
pass
|
||||
881
agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py
Normal file
881
agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py
Normal file
@@ -0,0 +1,881 @@
|
||||
import logging
|
||||
import re
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
CodeExecutionEvent,
|
||||
CodeGenerationEvent,
|
||||
HandoffMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
ThoughtEvent,
|
||||
)
|
||||
from ..utils import remove_images
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class CodeExecutorAgentConfig(BaseModel):
|
||||
"""Configuration for CodeExecutorAgent"""
|
||||
|
||||
name: str
|
||||
code_executor: ComponentModel
|
||||
model_client: ComponentModel | None = None
|
||||
description: str | None = None
|
||||
sources: List[str] | None = None
|
||||
system_message: str | None = None
|
||||
model_client_stream: bool = False
|
||||
model_context: ComponentModel | None = None
|
||||
supported_languages: List[str] | None = None
|
||||
|
||||
|
||||
class RetryDecision(BaseModel):
|
||||
reason: str
|
||||
retry: bool
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request for approval of code execution."""
|
||||
|
||||
code: str
|
||||
context: List[LLMMessage]
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
"""Response to approval request."""
|
||||
|
||||
approved: bool
|
||||
reason: str
|
||||
|
||||
|
||||
# Type aliases for approval functions
|
||||
SyncApprovalFunc = Callable[[ApprovalRequest], ApprovalResponse]
|
||||
AsyncApprovalFunc = Callable[[ApprovalRequest], Awaitable[ApprovalResponse]]
|
||||
ApprovalFuncType = Union[SyncApprovalFunc, AsyncApprovalFunc]
|
||||
|
||||
|
||||
class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]):
|
||||
"""(Experimental) An agent that generates and executes code snippets based on user instructions.
|
||||
|
||||
.. note::
|
||||
|
||||
This agent is experimental and may change in future releases.
|
||||
|
||||
It is typically used within a team with another agent that generates code snippets
|
||||
to be executed or alone with `model_client` provided so that it can generate code
|
||||
based on user query, execute it and reflect on the code result.
|
||||
|
||||
When used with `model_client`, it will generate code snippets using the model
|
||||
and execute them using the provided `code_executor`. The model will also reflect on the
|
||||
code execution results. The agent will yield the final reflection result from the model
|
||||
as the final response.
|
||||
|
||||
When used without `model_client`, it will only execute code blocks found in
|
||||
:class:`~agentdhal_agentchat.messages.TextMessage` messages and returns the output
|
||||
of the code execution.
|
||||
|
||||
.. note::
|
||||
|
||||
Using :class:`~agentdhal_agentchat.agents.AssistantAgent` with
|
||||
:class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`
|
||||
is an alternative to this agent. However, the model for that agent will
|
||||
have to generate properly escaped code string as a parameter to the tool.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
code_executor (CodeExecutor): The code executor responsible for executing code received in messages
|
||||
(:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` recommended. See example below)
|
||||
model_client (ChatCompletionClient, optional): The model client to use for inference and generating code.
|
||||
If not provided, the agent will only execute code blocks found in input messages.
|
||||
Currently, the model must support structured output mode, which is required for
|
||||
the automatic retry mechanism to work.
|
||||
model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
|
||||
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will
|
||||
also yield :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent`
|
||||
messages as the model client produces chunks of response. Defaults to `False`.
|
||||
description (str, optional): The description of the agent. If not provided,
|
||||
:class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION` will be used.
|
||||
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
|
||||
Defaults to :class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_SYSTEM_MESSAGE`. This is only used if `model_client` is provided.
|
||||
sources (Sequence[str], optional): Check only messages from the specified agents for the code to execute.
|
||||
This is useful when the agent is part of a group chat and you want to limit the code execution to messages from specific agents.
|
||||
If not provided, all messages will be checked for code blocks.
|
||||
This is only used if `model_client` is not provided.
|
||||
max_retries_on_error (int, optional): The maximum number of retries on error. If the code execution fails, the agent will retry up to this number of times.
|
||||
If the code execution fails after this number of retries, the agent will yield a reflection result.
|
||||
supported_languages (List[str], optional): List of programming languages that will be parsed and executed from agent response;
|
||||
others will be ignored. Defaults to DEFAULT_SUPPORTED_LANGUAGES.
|
||||
approval_func (Optional[Union[Callable[[ApprovalRequest], ApprovalResponse], Callable[[ApprovalRequest], Awaitable[ApprovalResponse]]]], optional): A function that is called before each code execution to get approval.
|
||||
The function takes an ApprovalRequest containing the code to be executed and the current context, and returns an ApprovalResponse.
|
||||
The function can be either synchronous or asynchronous. If None (default), all code executions are automatically approved.
|
||||
If set, the agent cannot be serialized using :meth:`~agentdhal_agentchat.agents.CodeExecutorAgent.dump_component`.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
It is recommended that the `CodeExecutorAgent` agent uses a Docker container to execute code. This ensures that model-generated code is executed in an isolated environment. To use Docker, your environment must have Docker installed and running.
|
||||
Follow the installation instructions for `Docker <https://docs.docker.com/get-docker/>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
The code executor only processes code that is properly formatted in markdown code blocks using triple backticks.
|
||||
For example:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
```python
|
||||
print("Hello World")
|
||||
```
|
||||
|
||||
# or
|
||||
|
||||
```sh
|
||||
echo "Hello World"
|
||||
```
|
||||
|
||||
In this example, we show how to set up a `CodeExecutorAgent` agent that uses the
|
||||
:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
|
||||
to execute code snippets in a Docker container. The `work_dir` parameter indicates
|
||||
where all executed files are first saved locally before being executed in the Docker container.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
def simple_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Simple approval function that requests user input for code execution approval.\"\"\"
|
||||
print("Code execution approval requested:")
|
||||
print("=" * 50)
|
||||
print(request.code)
|
||||
print("=" * 50)
|
||||
|
||||
while True:
|
||||
user_input = input("Do you want to execute this code? (y/n): ").strip().lower()
|
||||
if user_input in ['y', 'yes']:
|
||||
return ApprovalResponse(approved=True, reason='Approved by user')
|
||||
elif user_input in ['n', 'no']:
|
||||
return ApprovalResponse(approved=False, reason='Denied by user')
|
||||
else:
|
||||
print("Please enter 'y' for yes or 'n' for no.")
|
||||
|
||||
|
||||
async def run_code_executor_agent() -> None:
|
||||
# Create a code executor agent that uses a Docker container to execute code.
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
await code_executor.start()
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor",
|
||||
code_executor=code_executor,
|
||||
approval_func=simple_approval_func
|
||||
)
|
||||
|
||||
# Run the agent with a given code snippet.
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```python
|
||||
print('Hello world')
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
# Stop the code executor.
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(run_code_executor_agent())
|
||||
|
||||
In this example, we show how to set up a `CodeExecutorAgent` agent that uses the
|
||||
:py:class:`~docker.types.DeviceRequest` to expose a GPU to the container for cuda-accelerated code execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_core import CancellationToken
|
||||
from docker.types import DeviceRequest
|
||||
|
||||
|
||||
async def run_code_executor_agent() -> None:
|
||||
# Create a code executor agent that uses a Docker container to execute code.
|
||||
code_executor = DockerCommandLineCodeExecutor(
|
||||
work_dir="coding", device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])]
|
||||
)
|
||||
await code_executor.start()
|
||||
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=code_executor)
|
||||
|
||||
# Display the GPU information
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```sh
|
||||
nvidia-smi
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
# Stop the code executor.
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(run_code_executor_agent())
|
||||
|
||||
In the following example, we show how to setup `CodeExecutorAgent` without `model_client` parameter for executing code blocks generated by other agents in a group chat using :py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent, CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
termination_condition = MaxMessageTermination(3)
|
||||
|
||||
|
||||
def group_chat_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Approval function for group chat that allows basic Python operations.\"\"\"
|
||||
# Allow common safe operations
|
||||
safe_operations = ["print(", "import ", "def ", "class ", "if ", "for ", "while "]
|
||||
if any(op in request.code for op in safe_operations):
|
||||
return ApprovalResponse(approved=True, reason='Safe Python operation')
|
||||
|
||||
# Deny file system operations in group chat
|
||||
dangerous_operations = ["open(", "file(", "os.", "subprocess", "eval(", "exec("]
|
||||
if any(op in request.code for op in dangerous_operations):
|
||||
return ApprovalResponse(approved=False, reason='File system or dangerous operation not allowed')
|
||||
|
||||
return ApprovalResponse(approved=True, reason='Operation approved')
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
# define the Docker CLI Code Executor
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
|
||||
# start the execution container
|
||||
await code_executor.start()
|
||||
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor_agent",
|
||||
code_executor=code_executor,
|
||||
approval_func=group_chat_approval_func
|
||||
)
|
||||
coder_agent = AssistantAgent("coder_agent", model_client=model_client)
|
||||
|
||||
groupchat = RoundRobinGroupChat(
|
||||
participants=[coder_agent, code_executor_agent], termination_condition=termination_condition
|
||||
)
|
||||
|
||||
task = "Write python code to print Hello World!"
|
||||
await Console(groupchat.run_stream(task=task))
|
||||
|
||||
# stop the execution container
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- user ----------
|
||||
Write python code to print Hello World!
|
||||
---------- coder_agent ----------
|
||||
Certainly! Here's a simple Python code to print "Hello World!":
|
||||
|
||||
```python
|
||||
print("Hello World!")
|
||||
```
|
||||
|
||||
You can run this code in any Python environment to display the message.
|
||||
---------- code_executor_agent ----------
|
||||
Hello World!
|
||||
|
||||
In the following example, we show how to setup `CodeExecutorAgent` with `model_client`
|
||||
that can generate its own code without the help of any other agent and executing it in
|
||||
:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`.
|
||||
It also demonstrates using a model-based approval function that reviews the code for safety before execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_core.models import SystemMessage, UserMessage
|
||||
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.conditions import TextMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
termination_condition = TextMessageTermination("code_executor_agent")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def model_client_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
instruction = "Approve or reject the code in the last message based on whether it is dangerous or not. Use the following JSON format for your response: {approved: true/false, reason: 'your reason here'}"
|
||||
response = await model_client.create(
|
||||
messages=[SystemMessage(content=instruction)]
|
||||
+ request.context
|
||||
+ [UserMessage(content=request.code, source="user")],
|
||||
json_output=ApprovalResponse,
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
return ApprovalResponse.model_validate_json(response.content)
|
||||
|
||||
# define the Docker CLI Code Executor
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
|
||||
# start the execution container
|
||||
await code_executor.start()
|
||||
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor_agent",
|
||||
code_executor=code_executor,
|
||||
model_client=model_client,
|
||||
approval_func=model_client_approval_func,
|
||||
)
|
||||
|
||||
task = "Write python code to print Hello World!"
|
||||
await Console(code_executor_agent.run_stream(task=task))
|
||||
|
||||
# stop the execution container
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- user ----------
|
||||
Write python code to print Hello World!
|
||||
---------- code_executor_agent ----------
|
||||
Certainly! Here is a simple Python code to print "Hello World!" to the console:
|
||||
|
||||
```python
|
||||
print("Hello World!")
|
||||
```
|
||||
|
||||
Let's execute it to confirm the output.
|
||||
---------- code_executor_agent ----------
|
||||
Hello World!
|
||||
|
||||
---------- code_executor_agent ----------
|
||||
The code has been executed successfully, and it printed "Hello World!" as expected. If you have any more requests or questions, feel free to ask!
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_TERMINAL_DESCRIPTION = "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks)."
|
||||
DEFAULT_AGENT_DESCRIPTION = "A Code Execution Agent that generates and executes Python and shell scripts based on user instructions. It ensures correctness, efficiency, and minimal errors while gracefully handling edge cases."
|
||||
DEFAULT_SYSTEM_MESSAGE = "You are a Code Execution Agent. Your role is to generate and execute Python code and shell scripts based on user instructions, ensuring correctness, efficiency, and minimal errors. Handle edge cases gracefully. Python code should be provided in ```python code blocks, and sh shell scripts should be provided in ```sh code blocks for execution."
|
||||
NO_CODE_BLOCKS_FOUND_MESSAGE = "No code blocks found in the thread. Please provide at least one markdown-encoded code block to execute (i.e., quoting code in ```python or ```sh code blocks)."
|
||||
DEFAULT_SUPPORTED_LANGUAGES = ["python", "sh"]
|
||||
|
||||
component_config_schema = CodeExecutorAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.CodeExecutorAgent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
code_executor: CodeExecutor,
|
||||
*,
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
model_client_stream: bool = False,
|
||||
max_retries_on_error: int = 0,
|
||||
description: str | None = None,
|
||||
system_message: str | None = DEFAULT_SYSTEM_MESSAGE,
|
||||
sources: Sequence[str] | None = None,
|
||||
supported_languages: List[str] | None = None,
|
||||
approval_func: Optional[ApprovalFuncType] = None,
|
||||
) -> None:
|
||||
if description is None:
|
||||
if model_client is None:
|
||||
description = CodeExecutorAgent.DEFAULT_TERMINAL_DESCRIPTION
|
||||
else:
|
||||
description = CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION
|
||||
|
||||
super().__init__(name=name, description=description)
|
||||
self._code_executor = code_executor
|
||||
self._sources = sources
|
||||
self._model_client_stream = model_client_stream
|
||||
self._max_retries_on_error = max_retries_on_error
|
||||
self._approval_func = approval_func
|
||||
self._approval_func_is_async = approval_func is not None and iscoroutinefunction(approval_func)
|
||||
|
||||
if supported_languages is not None:
|
||||
self._supported_languages = supported_languages
|
||||
else:
|
||||
self._supported_languages = CodeExecutorAgent.DEFAULT_SUPPORTED_LANGUAGES
|
||||
|
||||
self._supported_languages_regex = "|".join(re.escape(lang) for lang in self._supported_languages)
|
||||
|
||||
self._model_client = None
|
||||
if model_client is not None:
|
||||
self._model_client = model_client
|
||||
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
|
||||
self._system_messaages: List[SystemMessage] = []
|
||||
if system_message is None:
|
||||
self._system_messages = []
|
||||
else:
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
|
||||
if self._max_retries_on_error > 0:
|
||||
if not self._model_client or not self._model_client.model_info:
|
||||
raise ValueError("model_client.model_info must be provided when max_retries_on_error > 0")
|
||||
if not self._model_client.model_info["structured_output"]:
|
||||
raise ValueError("Specified model_client doesn't support structured output mode.")
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the code executor agent produces."""
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def model_context(self) -> ChatCompletionContext:
|
||||
"""
|
||||
The model context in use by the agent.
|
||||
"""
|
||||
return self._model_context
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""
|
||||
Process the incoming messages with the assistant agent and yield events/responses as they happen.
|
||||
"""
|
||||
|
||||
# Gather all relevant state here
|
||||
agent_name = self.name
|
||||
model_context = self._model_context
|
||||
system_messages = self._system_messages
|
||||
model_client = self._model_client
|
||||
model_client_stream = self._model_client_stream
|
||||
max_retries_on_error = self._max_retries_on_error
|
||||
|
||||
execution_result: CodeResult | None = None
|
||||
if model_client is None: # default behaviour for backward compatibility
|
||||
# execute generated code if present
|
||||
code_blocks: List[CodeBlock] = await self.extract_code_blocks_from_messages(messages)
|
||||
if not code_blocks:
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=self.NO_CODE_BLOCKS_FOUND_MESSAGE,
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
return
|
||||
execution_result = await self.execute_code_block(code_blocks, cancellation_token)
|
||||
yield Response(chat_message=TextMessage(content=execution_result.output, source=self.name))
|
||||
return
|
||||
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
|
||||
for nth_try in range(max_retries_on_error + 1): # Do one default generation, execution and inference loop
|
||||
# Step 1: Add new user/handoff messages to the model context
|
||||
await self._add_messages_to_context(
|
||||
model_context=model_context,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Step 2: Run inference with the model context
|
||||
model_result = None
|
||||
async for inference_output in self._call_llm(
|
||||
model_client=model_client,
|
||||
model_client_stream=model_client_stream,
|
||||
system_messages=system_messages,
|
||||
model_context=model_context,
|
||||
agent_name=agent_name,
|
||||
cancellation_token=cancellation_token,
|
||||
):
|
||||
if isinstance(inference_output, CreateResult):
|
||||
model_result = inference_output
|
||||
else:
|
||||
# Streaming chunk event
|
||||
yield inference_output
|
||||
|
||||
assert model_result is not None, "No model result was produced."
|
||||
|
||||
# Step 3: [NEW] If the model produced a hidden "thought," yield it as an event
|
||||
if model_result.thought:
|
||||
thought_event = ThoughtEvent(content=model_result.thought, source=agent_name)
|
||||
yield thought_event
|
||||
inner_messages.append(thought_event)
|
||||
|
||||
# Step 4: Add the assistant message to the model context (including thought if present)
|
||||
await model_context.add_message(
|
||||
AssistantMessage(
|
||||
content=model_result.content,
|
||||
source=agent_name,
|
||||
thought=getattr(model_result, "thought", None),
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Extract the code blocks from inferred text
|
||||
assert isinstance(model_result.content, str), "Expected inferred model_result.content to be of type str."
|
||||
code_blocks = self._extract_markdown_code_blocks(str(model_result.content))
|
||||
|
||||
# Step 6: Exit the loop if no code blocks found
|
||||
if not code_blocks:
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=str(model_result.content),
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Step 7: Yield a CodeGenerationEvent
|
||||
inferred_text_message: CodeGenerationEvent = CodeGenerationEvent(
|
||||
retry_attempt=nth_try,
|
||||
content=model_result.content,
|
||||
code_blocks=code_blocks,
|
||||
source=agent_name,
|
||||
)
|
||||
|
||||
yield inferred_text_message
|
||||
|
||||
# Step 8: Execute the extracted code blocks
|
||||
execution_result = await self.execute_code_block(inferred_text_message.code_blocks, cancellation_token)
|
||||
|
||||
# Step 9: Update model context with the code execution result
|
||||
await model_context.add_message(
|
||||
UserMessage(
|
||||
content=execution_result.output,
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Step 10: Yield a CodeExecutionEvent
|
||||
yield CodeExecutionEvent(retry_attempt=nth_try, result=execution_result, source=self.name)
|
||||
|
||||
# If execution was successful or last retry, then exit
|
||||
if execution_result.exit_code == 0 or nth_try == max_retries_on_error:
|
||||
break
|
||||
|
||||
# Step 11: If exit code is non-zero and retries are available then
|
||||
# make an inference asking if we should retry or not
|
||||
chat_context = await model_context.get_messages()
|
||||
|
||||
retry_prompt = (
|
||||
f"The most recent code execution resulted in an error:\n{execution_result.output}\n\n"
|
||||
"Should we attempt to resolve it? Please respond with:\n"
|
||||
"- A boolean value for 'retry' indicating whether it should be retried.\n"
|
||||
"- A detailed explanation in 'reason' that identifies the issue, justifies your decision to retry or not, and outlines how you would resolve the error if a retry is attempted."
|
||||
)
|
||||
|
||||
chat_context = chat_context + [
|
||||
UserMessage(
|
||||
content=retry_prompt,
|
||||
source=agent_name,
|
||||
)
|
||||
]
|
||||
|
||||
response = await model_client.create(messages=chat_context, json_output=RetryDecision)
|
||||
|
||||
assert isinstance(
|
||||
response.content, str
|
||||
), "Expected structured response for retry decision to be of type str."
|
||||
should_retry_generation = RetryDecision.model_validate_json(str(response.content))
|
||||
|
||||
# Exit if no-retry is needed
|
||||
if not should_retry_generation.retry:
|
||||
break
|
||||
|
||||
yield CodeGenerationEvent(
|
||||
retry_attempt=nth_try,
|
||||
content=f"Attempt number: {nth_try + 1}\nProposed correction: {should_retry_generation.reason}",
|
||||
code_blocks=[],
|
||||
source=agent_name,
|
||||
)
|
||||
|
||||
# Always reflect on the execution result
|
||||
async for reflection_response in CodeExecutorAgent._reflect_on_code_block_results_flow(
|
||||
system_messages=system_messages,
|
||||
model_client=model_client,
|
||||
model_client_stream=model_client_stream,
|
||||
model_context=model_context,
|
||||
agent_name=agent_name,
|
||||
inner_messages=inner_messages,
|
||||
):
|
||||
yield reflection_response # Last reflection_response is of type Response so it will finish the routine
|
||||
|
||||
async def extract_code_blocks_from_messages(self, messages: Sequence[BaseChatMessage]) -> List[CodeBlock]:
|
||||
# Extract code blocks from the messages.
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for msg in messages:
|
||||
if self._sources is None or msg.source in self._sources:
|
||||
if isinstance(msg, TextMessage):
|
||||
code_blocks.extend(self._extract_markdown_code_blocks(msg.content))
|
||||
# TODO: handle other message types if needed
|
||||
return code_blocks
|
||||
|
||||
async def execute_code_block(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CodeResult:
|
||||
# Check for approval before executing code blocks
|
||||
if self._approval_func is not None:
|
||||
# Combine all code blocks into a single string for approval
|
||||
combined_code = "\n\n".join([f"```{block.language}\n{block.code}\n```" for block in code_blocks])
|
||||
|
||||
# Get the current context from model_context
|
||||
context_messages = await self._model_context.get_messages()
|
||||
|
||||
# Create approval request
|
||||
approval_request = ApprovalRequest(code=combined_code, context=context_messages)
|
||||
|
||||
# Get approval (handle both sync and async functions)
|
||||
if self._approval_func_is_async:
|
||||
# Cast to AsyncApprovalFunc for proper typing
|
||||
async_func = cast(AsyncApprovalFunc, self._approval_func)
|
||||
approval_response = await async_func(approval_request)
|
||||
else:
|
||||
# Cast to SyncApprovalFunc for proper typing
|
||||
sync_func = cast(SyncApprovalFunc, self._approval_func)
|
||||
approval_response = sync_func(approval_request)
|
||||
|
||||
# If not approved, return error result
|
||||
if not approval_response.approved:
|
||||
return CodeResult(
|
||||
exit_code=1, output=f"Code execution was not approved. Reason: {approval_response.reason}"
|
||||
)
|
||||
|
||||
# Execute the code blocks.
|
||||
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||
|
||||
if result.output.strip() == "":
|
||||
# No output
|
||||
result.output = f"The script ran but produced no output to console. The POSIX exit code was: {result.exit_code}. If you were expecting output, consider revising the script to ensure content is printed to stdout."
|
||||
elif result.exit_code != 0:
|
||||
# Error
|
||||
result.output = f"The script ran, then exited with an error (POSIX exit code: {result.exit_code})\nIts output was:\n{result.output}"
|
||||
|
||||
return result
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Its a no-op as the code executor agent has no mutable state."""
|
||||
pass
|
||||
|
||||
def _extract_markdown_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
|
||||
pattern = re.compile(rf"```(?:\s*({self._supported_languages_regex}))\n([\s\S]*?)```", re.IGNORECASE)
|
||||
matches = pattern.findall(markdown_text)
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for match in matches:
|
||||
language = match[0].strip() if match[0] else ""
|
||||
code_content = match[1]
|
||||
code_blocks.append(CodeBlock(code=code_content, language=language))
|
||||
return code_blocks
|
||||
|
||||
def _to_config(self) -> CodeExecutorAgentConfig:
|
||||
if self._approval_func is not None:
|
||||
raise ValueError(
|
||||
"Cannot serialize CodeExecutorAgent with approval_func set. The approval function is not serializable."
|
||||
)
|
||||
|
||||
return CodeExecutorAgentConfig(
|
||||
name=self.name,
|
||||
model_client=(self._model_client.dump_component() if self._model_client is not None else None),
|
||||
code_executor=self._code_executor.dump_component(),
|
||||
description=self.description,
|
||||
sources=list(self._sources) if self._sources is not None else None,
|
||||
system_message=(
|
||||
self._system_messages[0].content
|
||||
if self._system_messages and isinstance(self._system_messages[0].content, str)
|
||||
else None
|
||||
),
|
||||
model_client_stream=self._model_client_stream,
|
||||
model_context=self._model_context.dump_component(),
|
||||
supported_languages=self._supported_languages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: CodeExecutorAgentConfig) -> Self:
|
||||
return cls(
|
||||
name=config.name,
|
||||
model_client=(
|
||||
ChatCompletionClient.load_component(config.model_client) if config.model_client is not None else None
|
||||
),
|
||||
code_executor=CodeExecutor.load_component(config.code_executor),
|
||||
description=config.description,
|
||||
sources=config.sources,
|
||||
system_message=config.system_message,
|
||||
model_client_stream=config.model_client_stream,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
supported_languages=config.supported_languages,
|
||||
approval_func=None, # approval_func cannot be serialized, so it's always None when loading from config
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_compatible_context(model_client: ChatCompletionClient, messages: List[LLMMessage]) -> Sequence[LLMMessage]:
|
||||
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
|
||||
if model_client.model_info["vision"]:
|
||||
return messages
|
||||
else:
|
||||
return remove_images(messages)
|
||||
|
||||
@classmethod
|
||||
async def _call_llm(
|
||||
cls,
|
||||
model_client: ChatCompletionClient,
|
||||
model_client_stream: bool,
|
||||
system_messages: List[SystemMessage],
|
||||
model_context: ChatCompletionContext,
|
||||
agent_name: str,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
|
||||
"""
|
||||
Perform a model inference and yield either streaming chunk events or the final CreateResult.
|
||||
"""
|
||||
all_messages = await model_context.get_messages()
|
||||
llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
|
||||
|
||||
if model_client_stream:
|
||||
model_result: Optional[CreateResult] = None
|
||||
async for chunk in model_client.create_stream(
|
||||
llm_messages, tools=[], cancellation_token=cancellation_token
|
||||
):
|
||||
if isinstance(chunk, CreateResult):
|
||||
model_result = chunk
|
||||
elif isinstance(chunk, str):
|
||||
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
|
||||
if model_result is None:
|
||||
raise RuntimeError("No final model result in streaming mode.")
|
||||
yield model_result
|
||||
else:
|
||||
model_result = await model_client.create(llm_messages, tools=[], cancellation_token=cancellation_token)
|
||||
yield model_result
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
@classmethod
|
||||
async def _reflect_on_code_block_results_flow(
|
||||
cls,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
model_client_stream: bool,
|
||||
model_context: ChatCompletionContext,
|
||||
agent_name: str,
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
|
||||
"""
|
||||
If reflect_on_code_block_results=True, we do another inference based on tool results
|
||||
and yield the final text response (or streaming chunks).
|
||||
"""
|
||||
all_messages = system_messages + await model_context.get_messages()
|
||||
llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages)
|
||||
|
||||
reflection_result: Optional[CreateResult] = None
|
||||
|
||||
if model_client_stream:
|
||||
async for chunk in model_client.create_stream(llm_messages):
|
||||
if isinstance(chunk, CreateResult):
|
||||
reflection_result = chunk
|
||||
elif isinstance(chunk, str):
|
||||
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
|
||||
else:
|
||||
reflection_result = await model_client.create(llm_messages)
|
||||
|
||||
if not reflection_result or not isinstance(reflection_result.content, str):
|
||||
raise RuntimeError("Reflect on tool use produced no valid text response.")
|
||||
|
||||
# --- NEW: If the reflection produced a thought, yield it ---
|
||||
if reflection_result.thought:
|
||||
thought_event = ThoughtEvent(content=reflection_result.thought, source=agent_name)
|
||||
yield thought_event
|
||||
inner_messages.append(thought_event)
|
||||
|
||||
# Add to context (including thought if present)
|
||||
await model_context.add_message(
|
||||
AssistantMessage(
|
||||
content=reflection_result.content,
|
||||
source=agent_name,
|
||||
thought=getattr(reflection_result, "thought", None),
|
||||
)
|
||||
)
|
||||
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=reflection_result.content,
|
||||
source=agent_name,
|
||||
models_usage=reflection_result.usage,
|
||||
),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
203
agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py
Normal file
203
agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
# ------------------------------
|
||||
# Message Filter Config
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class PerSourceFilter(BaseModel):
|
||||
source: str
|
||||
position: Optional[Literal["first", "last"]] = None
|
||||
count: Optional[int] = None
|
||||
|
||||
|
||||
class MessageFilterConfig(BaseModel):
|
||||
per_source: List[PerSourceFilter]
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Component Config
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class MessageFilterAgentConfig(BaseModel):
|
||||
name: str
|
||||
wrapped_agent: ComponentModel
|
||||
filter: MessageFilterConfig
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Message Filter Agent
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class MessageFilterAgent(BaseChatAgent, Component[MessageFilterAgentConfig]):
|
||||
"""
|
||||
A wrapper agent that filters incoming messages before passing them to the inner agent.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This is useful in scenarios like multi-agent workflows where an agent should only
|
||||
process a subset of the full message history—for example, only the last message
|
||||
from each upstream agent, or only the first message from a specific source.
|
||||
|
||||
Filtering is configured using :class:`MessageFilterConfig`, which supports:
|
||||
- Filtering by message source (e.g., only messages from "user" or another agent)
|
||||
- Selecting the first N or last N messages from each source
|
||||
- If position is `None`, all messages from that source are included
|
||||
|
||||
This agent is compatible with both direct message passing and team-based execution
|
||||
such as :class:`~agentdhal_agentchat.teams.GraphFlow`.
|
||||
|
||||
Example:
|
||||
>>> agent_a = MessageFilterAgent(
|
||||
... name="A",
|
||||
... wrapped_agent=some_other_agent,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=2),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
Example use case with Graph:
|
||||
Suppose you have a looping multi-agent graph: A → B → A → B → C.
|
||||
|
||||
You want:
|
||||
- A to only see the user message and the last message from B
|
||||
- B to see the user message, last message from A, and its own prior responses (for reflection)
|
||||
- C to see the user message and the last message from B
|
||||
|
||||
Wrap the agents like so:
|
||||
|
||||
>>> agent_a = MessageFilterAgent(
|
||||
... name="A",
|
||||
... wrapped_agent=agent_a_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=1),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
>>> agent_b = MessageFilterAgent(
|
||||
... name="B",
|
||||
... wrapped_agent=agent_b_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="A", position="last", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=10),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
>>> agent_c = MessageFilterAgent(
|
||||
... name="C",
|
||||
... wrapped_agent=agent_c_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=1),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
Then define the graph:
|
||||
|
||||
>>> graph = DiGraph(
|
||||
... nodes={
|
||||
... "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
... "B": DiGraphNode(
|
||||
... name="B",
|
||||
... edges=[
|
||||
... DiGraphEdge(target="C", condition="exit"),
|
||||
... DiGraphEdge(target="A", condition="loop"),
|
||||
... ],
|
||||
... ),
|
||||
... "C": DiGraphNode(name="C", edges=[]),
|
||||
... },
|
||||
... default_start_node="A",
|
||||
... )
|
||||
|
||||
This will ensure each agent sees only what is needed for its decision or action logic.
|
||||
"""
|
||||
|
||||
component_config_schema = MessageFilterAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.MessageFilterAgent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
wrapped_agent: BaseChatAgent,
|
||||
filter: MessageFilterConfig,
|
||||
):
|
||||
super().__init__(name=name, description=f"{wrapped_agent.description} (with message filtering)")
|
||||
self._wrapped_agent = wrapped_agent
|
||||
self._filter = filter
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return self._wrapped_agent.produced_message_types
|
||||
|
||||
def _apply_filter(self, messages: Sequence[BaseChatMessage]) -> Sequence[BaseChatMessage]:
|
||||
result: List[BaseChatMessage] = []
|
||||
|
||||
for source_filter in self._filter.per_source:
|
||||
msgs = [m for m in messages if m.source == source_filter.source]
|
||||
|
||||
if source_filter.position == "first" and source_filter.count:
|
||||
msgs = msgs[: source_filter.count]
|
||||
elif source_filter.position == "last" and source_filter.count:
|
||||
msgs = msgs[-source_filter.count :]
|
||||
|
||||
result.extend(msgs)
|
||||
|
||||
return result
|
||||
|
||||
async def on_messages(
|
||||
self,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Response:
|
||||
filtered = self._apply_filter(messages)
|
||||
return await self._wrapped_agent.on_messages(filtered, cancellation_token)
|
||||
|
||||
async def on_messages_stream(
|
||||
self,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
cancellation_token: CancellationToken,
|
||||
) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]:
|
||||
filtered = self._apply_filter(messages)
|
||||
async for item in self._wrapped_agent.on_messages_stream(filtered, cancellation_token):
|
||||
yield item
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
await self._wrapped_agent.on_reset(cancellation_token)
|
||||
|
||||
def _to_config(self) -> MessageFilterAgentConfig:
|
||||
return MessageFilterAgentConfig(
|
||||
name=self.name,
|
||||
wrapped_agent=self._wrapped_agent.dump_component(),
|
||||
filter=self._filter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MessageFilterAgentConfig) -> "MessageFilterAgent":
|
||||
wrapped = BaseChatAgent.load_component(config.wrapped_agent)
|
||||
return cls(
|
||||
name=config.name,
|
||||
wrapped_agent=wrapped,
|
||||
filter=config.filter,
|
||||
)
|
||||
302
agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py
Normal file
302
agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py
Normal file
@@ -0,0 +1,302 @@
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.state import SocietyOfMindAgentState
|
||||
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
||||
class SocietyOfMindAgentConfig(BaseModel):
|
||||
"""The declarative configuration for a SocietyOfMindAgent."""
|
||||
|
||||
name: str
|
||||
team: ComponentModel
|
||||
model_client: ComponentModel
|
||||
description: str | None = None
|
||||
instruction: str | None = None
|
||||
response_prompt: str | None = None
|
||||
model_context: ComponentModel | None = None
|
||||
|
||||
|
||||
class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
||||
"""An agent that uses an inner team of agents to generate responses.
|
||||
|
||||
Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream`
|
||||
method is called, it runs the inner team of agents and then uses the
|
||||
model client to generate a response based on the inner team's messages.
|
||||
Once the response is generated, the agent resets the inner team by
|
||||
calling :meth:`Team.reset`.
|
||||
|
||||
Limit context size sent to the model:
|
||||
|
||||
You can limit the number of messages sent to the model by setting
|
||||
the `model_context` parameter to a :class:`~agentdhal_core.model_context.BufferedChatCompletionContext`.
|
||||
This will limit the number of recent messages sent to the model and can be useful
|
||||
when the model has a limit on the number of tokens it can process.
|
||||
You can also create your own model context by subclassing
|
||||
:class:`~agentdhal_core.model_context.ChatCompletionContext`.
|
||||
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
team (Team): The team of agents to use.
|
||||
model_client (ChatCompletionClient): The model client to use for preparing responses.
|
||||
description (str, optional): The description of the agent.
|
||||
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
|
||||
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
|
||||
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
|
||||
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
|
||||
agent2 = AssistantAgent(
|
||||
"assistant2",
|
||||
model_client=model_client,
|
||||
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
|
||||
)
|
||||
inner_termination = TextMentionTermination("APPROVE")
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
|
||||
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
|
||||
agent3 = AssistantAgent(
|
||||
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
|
||||
)
|
||||
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
|
||||
|
||||
stream = team.run_stream(task="Write a short story with a surprising ending.")
|
||||
await Console(stream)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SocietyOfMindAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.SocietyOfMindAgent"
|
||||
|
||||
DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
|
||||
"""str: The default instruction to use when generating a response using the
|
||||
inner team's messages. The instruction will be prepended to the inner team's
|
||||
messages when generating a response using the model. It assumes the role of
|
||||
'system'."""
|
||||
|
||||
DEFAULT_RESPONSE_PROMPT = (
|
||||
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
|
||||
)
|
||||
"""str: The default response prompt to use when generating a response using
|
||||
the inner team's messages. It assumes the role of 'system'."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that uses an inner team of agents to generate responses."
|
||||
"""str: The default description for a SocietyOfMindAgent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
team: Team,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
instruction: str = DEFAULT_INSTRUCTION,
|
||||
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
) -> None:
|
||||
super().__init__(name=name, description=description)
|
||||
self._team = team
|
||||
self._model_client = model_client
|
||||
self._instruction = instruction
|
||||
self._response_prompt = response_prompt
|
||||
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def model_context(self) -> ChatCompletionContext:
|
||||
"""
|
||||
The model context in use by the agent.
|
||||
"""
|
||||
return self._model_context
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Call the stream method and collect the messages.
|
||||
response: Response | None = None
|
||||
async for msg in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
response = msg
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
# Prepare the task for the team of agents.
|
||||
task_messages = list(messages)
|
||||
|
||||
# Run the team of agents.
|
||||
result: TaskResult | None = None
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
model_context = self._model_context
|
||||
|
||||
prev_content = await model_context.get_messages()
|
||||
if len(prev_content) > 0:
|
||||
prev_message = HandoffMessage(
|
||||
content="relevant previous messages",
|
||||
source=self.name,
|
||||
target="",
|
||||
context=prev_content,
|
||||
)
|
||||
task_messages = [prev_message] + task_messages
|
||||
|
||||
if len(task_messages) == 0:
|
||||
task = None
|
||||
else:
|
||||
task = task_messages
|
||||
|
||||
# Use the new output_task_messages parameter to avoid fragile count-based logic
|
||||
async for inner_msg in self._team.run_stream(
|
||||
task=task, cancellation_token=cancellation_token, output_task_messages=False
|
||||
):
|
||||
if isinstance(inner_msg, TaskResult):
|
||||
result = inner_msg
|
||||
else:
|
||||
yield inner_msg
|
||||
if isinstance(inner_msg, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
inner_messages.append(inner_msg)
|
||||
assert result is not None
|
||||
|
||||
if len(inner_messages) == 0:
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content="No response."),
|
||||
inner_messages=[],
|
||||
# Response's inner_messages should be empty. Cause that mean is response to outer world.
|
||||
)
|
||||
else:
|
||||
llm_messages: List[LLMMessage] = []
|
||||
|
||||
if self._model_client.model_info.get("multiple_system_messages", False):
|
||||
# The model client supports multiple system messages, so we
|
||||
llm_messages.append(SystemMessage(content=self._instruction))
|
||||
else:
|
||||
# The model client does not support multiple system messages, so we
|
||||
llm_messages.append(UserMessage(content=self._instruction, source="user"))
|
||||
|
||||
# Generate a response using the model client.
|
||||
for message in inner_messages:
|
||||
if isinstance(message, BaseChatMessage):
|
||||
llm_messages.append(message.to_model_message())
|
||||
|
||||
if self._model_client.model_info.get("multiple_system_messages", False):
|
||||
# The model client supports multiple system messages, so we
|
||||
llm_messages.append(SystemMessage(content=self._response_prompt))
|
||||
else:
|
||||
# The model client does not support multiple system messages, so we
|
||||
llm_messages.append(UserMessage(content=self._response_prompt, source="user"))
|
||||
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
|
||||
assert isinstance(completion.content, str)
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
|
||||
inner_messages=[],
|
||||
# Response's inner_messages should be empty. Cause that mean is response to outer world.
|
||||
)
|
||||
|
||||
# Add new user/handoff messages to the model context
|
||||
await self._add_messages_to_context(
|
||||
model_context=model_context,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Reset the team.
|
||||
await self._team.reset()
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
await self._team.reset()
|
||||
await self._model_context.clear()
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
team_state = await self._team.save_state()
|
||||
state = SocietyOfMindAgentState(inner_team_state=team_state)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
|
||||
await self._team.load_state(society_of_mind_state.inner_team_state)
|
||||
|
||||
def _to_config(self) -> SocietyOfMindAgentConfig:
|
||||
return SocietyOfMindAgentConfig(
|
||||
name=self.name,
|
||||
team=self._team.dump_component(),
|
||||
model_client=self._model_client.dump_component(),
|
||||
description=self.description,
|
||||
instruction=self._instruction,
|
||||
response_prompt=self._response_prompt,
|
||||
model_context=self._model_context.dump_component(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self:
|
||||
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||
team = Team.load_component(config.team)
|
||||
return cls(
|
||||
name=config.name,
|
||||
team=team,
|
||||
model_client=model_client,
|
||||
description=config.description or cls.DEFAULT_DESCRIPTION,
|
||||
instruction=config.instruction or cls.DEFAULT_INSTRUCTION,
|
||||
response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
)
|
||||
249
agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py
Normal file
249
agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||
|
||||
|
||||
# TODO: check if using to_thread fixes this in jupyter
|
||||
async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
task: asyncio.Task[str] = asyncio.create_task(asyncio.to_thread(input, prompt))
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(task)
|
||||
return await task
|
||||
|
||||
|
||||
class UserProxyAgentConfig(BaseModel):
|
||||
"""Declarative configuration for the UserProxyAgent."""
|
||||
|
||||
name: str
|
||||
description: str = "A human user"
|
||||
input_func: str | None = None
|
||||
|
||||
|
||||
class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
|
||||
"""An agent that can represent a human user through an input function.
|
||||
|
||||
This agent can be used to represent a human user in a chat system by providing a custom input function.
|
||||
|
||||
.. note::
|
||||
|
||||
Using :class:`UserProxyAgent` puts a running team in a temporary blocked
|
||||
state until the user responds. So it is important to time out the user input
|
||||
function and cancel using the :class:`~agentdhal_core.CancellationToken` if the user does not respond.
|
||||
The input function should also handle exceptions and return a default response if needed.
|
||||
|
||||
For typical use cases that involve
|
||||
slow human responses, it is recommended to use termination conditions
|
||||
such as :class:`~agentdhal_agentchat.conditions.HandoffTermination` or :class:`~agentdhal_agentchat.conditions.SourceMatchTermination`
|
||||
to stop the running team and return the control to the application.
|
||||
You can run the team again with the user input. This way, the state of the team
|
||||
can be saved and restored when the user responds.
|
||||
|
||||
See `Human-in-the-loop <https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.html>`_ for more information.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str, optional): A description of the agent.
|
||||
input_func (Optional[Callable[[str], str]], Callable[[str, Optional[CancellationToken]], Awaitable[str]]): A function that takes a prompt and returns a user input string.
|
||||
|
||||
For examples of integrating with web and UI frameworks, see the following:
|
||||
|
||||
* `FastAPI <https://github.com/microsoft/autogen/tree/main/python/samples/agentchat_fastapi>`_
|
||||
* `ChainLit <https://github.com/microsoft/autogen/tree/main/python/samples/agentchat_chainlit>`_
|
||||
|
||||
Example:
|
||||
Simple usage case::
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def simple_user_agent():
|
||||
agent = UserProxyAgent("user_proxy")
|
||||
response = await asyncio.create_task(
|
||||
agent.on_messages(
|
||||
[TextMessage(content="What is your name? ", source="user")],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
)
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
print(f"Your name is {response.chat_message.content}")
|
||||
|
||||
Example:
|
||||
Cancellable usage case::
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
token = CancellationToken()
|
||||
agent = UserProxyAgent("user_proxy")
|
||||
|
||||
|
||||
async def timeout(delay: float):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
def cancellation_callback(task: asyncio.Task[Any]):
|
||||
token.cancel()
|
||||
|
||||
|
||||
async def cancellable_user_agent():
|
||||
try:
|
||||
timeout_task = asyncio.create_task(timeout(3))
|
||||
timeout_task.add_done_callback(cancellation_callback)
|
||||
agent_task = asyncio.create_task(
|
||||
agent.on_messages(
|
||||
[TextMessage(content="What is your name? ", source="user")],
|
||||
cancellation_token=token,
|
||||
)
|
||||
)
|
||||
response = await agent_task
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
print(f"Your name is {response.chat_message.content}")
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
except BaseException as e:
|
||||
print(f"BaseException: {e}")
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
component_provider_override = "agentdhal_agentchat.agents.UserProxyAgent"
|
||||
component_config_schema = UserProxyAgentConfig
|
||||
|
||||
class InputRequestContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
|
||||
)
|
||||
|
||||
_INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def request_id(cls) -> str:
|
||||
try:
|
||||
return cls._INPUT_REQUEST_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"InputRequestContext.runtime() must be called within the input callback of a UserProxyAgent."
|
||||
) from e
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
description: str = "A human user",
|
||||
input_func: Optional[InputFuncType] = None,
|
||||
) -> None:
|
||||
"""Initialize the UserProxyAgent."""
|
||||
super().__init__(name=name, description=description)
|
||||
self.input_func = input_func or cancellable_input
|
||||
self._is_async = iscoroutinefunction(self.input_func)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""Message types this agent can produce."""
|
||||
return (TextMessage, HandoffMessage)
|
||||
|
||||
def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]:
|
||||
"""Find the HandoffMessage in the message sequence that addresses this agent."""
|
||||
if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
|
||||
if messages[-1].target == self.name:
|
||||
return messages[-1]
|
||||
else:
|
||||
raise RuntimeError(f"Handoff message target does not match agent name: {messages[-1].source}")
|
||||
return None
|
||||
|
||||
async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
"""Handle input based on function signature."""
|
||||
try:
|
||||
if self._is_async:
|
||||
# Cast to AsyncInputFunc for proper typing
|
||||
async_func = cast(AsyncInputFunc, self.input_func)
|
||||
return await async_func(prompt, cancellation_token)
|
||||
else:
|
||||
# Cast to SyncInputFunc for proper typing
|
||||
sync_func = cast(SyncInputFunc, self.input_func)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, sync_func, prompt)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handle incoming messages by requesting user input."""
|
||||
try:
|
||||
# Check for handoff first
|
||||
handoff = self._get_latest_handoff(messages)
|
||||
prompt = (
|
||||
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
|
||||
)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
|
||||
yield input_requested_event
|
||||
with UserProxyAgent.InputRequestContext.populate_context(request_id):
|
||||
user_input = await self._get_input(prompt, cancellation_token)
|
||||
|
||||
# Return appropriate message type based on handoff presence
|
||||
if handoff:
|
||||
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
|
||||
else:
|
||||
yield Response(chat_message=TextMessage(content=user_input, source=self.name))
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||
"""Reset agent state."""
|
||||
pass
|
||||
|
||||
def _to_config(self) -> UserProxyAgentConfig:
|
||||
# TODO: Add ability to serialie input_func
|
||||
return UserProxyAgentConfig(name=self.name, description=self.description, input_func=None)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: UserProxyAgentConfig) -> Self:
|
||||
return cls(name=config.name, description=config.description, input_func=None)
|
||||
18
agent_dhal/agentdhal_agentchat/base/__init__.py
Normal file
18
agent_dhal/agentdhal_agentchat/base/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from ._chat_agent import ChatAgent, Response
|
||||
from ._handoff import Handoff
|
||||
from ._task import TaskResult, TaskRunner
|
||||
from ._team import Team
|
||||
from ._termination import AndTerminationCondition, OrTerminationCondition, TerminatedException, TerminationCondition
|
||||
|
||||
__all__ = [
|
||||
"ChatAgent",
|
||||
"Response",
|
||||
"Team",
|
||||
"TerminatedException",
|
||||
"TerminationCondition",
|
||||
"AndTerminationCondition",
|
||||
"OrTerminationCondition",
|
||||
"TaskResult",
|
||||
"TaskRunner",
|
||||
"Handoff",
|
||||
]
|
||||
94
agent_dhal/agentdhal_agentchat/base/_chat_agent.py
Normal file
94
agent_dhal/agentdhal_agentchat/base/_chat_agent.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, ComponentBase
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Response:
|
||||
"""A response from calling :meth:`ChatAgent.on_messages`."""
|
||||
|
||||
chat_message: SerializeAsAny[BaseChatMessage]
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] | None = None
|
||||
"""Inner messages produced by the agent, they can be :class:`BaseAgentEvent`
|
||||
or :class:`BaseChatMessage`."""
|
||||
|
||||
|
||||
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
|
||||
"""Protocol for a chat agent."""
|
||||
|
||||
component_type = "agent"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the agent. This is used by team to uniquely identify
|
||||
the agent. It should be unique within the team."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""The description of the agent. This is used by team to
|
||||
make decisions about which agents to use. The description should
|
||||
describe the agent's capabilities and how to interact with it."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of inner messages and
|
||||
and the final item is the response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Resets the agent to its initialization state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_pause(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is paused. The agent may be running in :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` when this method is called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_resume(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is resumed. The agent may be running in :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` when this method is called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save agent state for later restoration"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Release any resources held by the agent."""
|
||||
...
|
||||
62
agent_dhal/agentdhal_agentchat/base/_handoff.py
Normal file
62
agent_dhal/agentdhal_agentchat/base/_handoff.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from agentdhal_core.tools import BaseTool, FunctionTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class Handoff(BaseModel):
|
||||
"""Handoff configuration."""
|
||||
|
||||
target: str
|
||||
"""The name of the target agent to handoff to."""
|
||||
|
||||
description: str = Field(default="")
|
||||
"""The description of the handoff such as the condition under which it should happen and the target agent's ability.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
name: str = Field(default="")
|
||||
"""The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
|
||||
|
||||
message: str = Field(default="")
|
||||
"""The message to the target agent.
|
||||
By default, it will be the result for the handoff tool.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if not values.get("description"):
|
||||
values["description"] = f"Handoff to {values['target']}."
|
||||
if not values.get("name"):
|
||||
values["name"] = f"transfer_to_{values['target']}".lower()
|
||||
else:
|
||||
name = values["name"]
|
||||
if not isinstance(name, str):
|
||||
raise ValueError(f"Handoff name must be a string: {values['name']}")
|
||||
# Check if name is a valid identifier.
|
||||
if not name.isidentifier():
|
||||
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
|
||||
if not values.get("message"):
|
||||
values["message"] = (
|
||||
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
|
||||
"""Create a handoff tool from this handoff configuration."""
|
||||
|
||||
def _handoff_tool() -> str:
|
||||
return self.message
|
||||
|
||||
return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True)
|
||||
|
||||
"""
|
||||
The tool that can be used to handoff to the target agent.
|
||||
Typically, the results of the tool's execution are provided to the target agent.
|
||||
"""
|
||||
65
agent_dhal/agentdhal_agentchat/base/_task.py
Normal file
65
agent_dhal/agentdhal_agentchat/base/_task.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import AsyncGenerator, Protocol, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
"""Result of running a task."""
|
||||
|
||||
messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]]
|
||||
"""Messages produced by the task."""
|
||||
|
||||
stop_reason: str | None = None
|
||||
"""The reason the task stopped."""
|
||||
|
||||
|
||||
class TaskRunner(Protocol):
|
||||
"""A task runner."""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in :attr:`TaskResult.messages`. Defaults to True for backward compatibility.
|
||||
"""
|
||||
...
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
:class:`TaskResult` as the last item in the stream.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
"""
|
||||
...
|
||||
54
agent_dhal/agentdhal_agentchat/base/_team.py
Normal file
54
agent_dhal/agentdhal_agentchat/base/_team.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping
|
||||
|
||||
from agentdhal_core import ComponentBase
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
class Team(ABC, TaskRunner, ComponentBase[BaseModel]):
|
||||
component_type = "team"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the team. This is used by team to uniquely identify itself
|
||||
in a larger team of teams."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""A description of the team. This is used to provide context about the
|
||||
team and its purpose to its parent orchestrator."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and all its participants to its initial state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def pause(self) -> None:
|
||||
"""Pause the team and all its participants. This is useful for
|
||||
pausing the :meth:`agentdhal_agentchat.base.TaskRunner.run` or
|
||||
:meth:`agentdhal_agentchat.base.TaskRunner.run_stream` methods from
|
||||
concurrently, while keeping them alive."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def resume(self) -> None:
|
||||
"""Resume the team and all its participants from a pause after
|
||||
:meth:`pause` was called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the current state of the team."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the team."""
|
||||
...
|
||||
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence
|
||||
|
||||
from agentdhal_core import Component, ComponentBase, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class TerminatedException(BaseException): ...
|
||||
|
||||
|
||||
class TerminationCondition(ABC, ComponentBase[BaseModel]):
|
||||
"""A stateful condition that determines when a conversation should be terminated.
|
||||
|
||||
A termination condition is a callable that takes a sequence of BaseChatMessage objects
|
||||
since the last time the condition was called, and returns a StopMessage if the
|
||||
conversation should be terminated, or None otherwise.
|
||||
Once a termination condition has been reached, it must be reset before it can be used again.
|
||||
|
||||
Termination conditions can be combined using the AND and OR operators.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination, TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
|
||||
cond1 = MaxMessageTermination(10) | TextMentionTermination("TERMINATE")
|
||||
|
||||
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
|
||||
cond2 = MaxMessageTermination(10) & TextMentionTermination("TERMINATE")
|
||||
|
||||
# ...
|
||||
|
||||
# Reset the termination condition.
|
||||
await cond1.reset()
|
||||
await cond2.reset()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "termination"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def terminated(self) -> bool:
|
||||
"""Check if the termination condition has been reached"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
"""Check if the conversation should be terminated based on the messages received
|
||||
since the last time the condition was called.
|
||||
Return a StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Args:
|
||||
messages: The messages received since the last time the condition was called.
|
||||
|
||||
Returns:
|
||||
StopMessage | None: A StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the termination condition."""
|
||||
...
|
||||
|
||||
def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an AND operation."""
|
||||
return AndTerminationCondition(self, other)
|
||||
|
||||
def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an OR operation."""
|
||||
return OrTerminationCondition(self, other)
|
||||
|
||||
|
||||
class AndTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
|
||||
|
||||
class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
|
||||
component_config_schema = AndTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.AndTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
self._stop_messages: List[StopMessage] = []
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return all(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached.")
|
||||
# Check all remaining conditions.
|
||||
stop_messages = await asyncio.gather(
|
||||
*[condition(messages) for condition in self._conditions if not condition.terminated]
|
||||
)
|
||||
# Collect stop messages.
|
||||
for stop_message in stop_messages:
|
||||
if stop_message is not None:
|
||||
self._stop_messages.append(stop_message)
|
||||
if any(stop_message is None for stop_message in stop_messages):
|
||||
# If any remaining condition has not reached termination, it is not terminated.
|
||||
return None
|
||||
content = ", ".join(stop_message.content for stop_message in self._stop_messages)
|
||||
source = ", ".join(stop_message.source for stop_message in self._stop_messages)
|
||||
return StopMessage(content=content, source=source)
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
self._stop_messages.clear()
|
||||
|
||||
def _to_config(self) -> AndTerminationConditionConfig:
|
||||
"""Convert the AND termination condition to a config."""
|
||||
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
|
||||
"""Create an AND termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
|
||||
|
||||
class OrTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
"""List of termination conditions where any one being satisfied is sufficient."""
|
||||
|
||||
|
||||
class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
|
||||
component_config_schema = OrTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.OrTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return any(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise RuntimeError("Termination condition has already been reached")
|
||||
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
|
||||
stop_messages_filter = [stop_message for stop_message in stop_messages if stop_message is not None]
|
||||
if len(stop_messages_filter) > 0:
|
||||
content = ", ".join(stop_message.content for stop_message in stop_messages_filter)
|
||||
source = ", ".join(stop_message.source for stop_message in stop_messages_filter)
|
||||
return StopMessage(content=content, source=source)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
|
||||
def _to_config(self) -> OrTerminationConditionConfig:
|
||||
"""Convert the OR termination condition to a config."""
|
||||
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
|
||||
"""Create an OR termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
32
agent_dhal/agentdhal_agentchat/conditions/__init__.py
Normal file
32
agent_dhal/agentdhal_agentchat/conditions/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
This module provides various termination conditions for controlling the behavior of
|
||||
multi-agent teams.
|
||||
"""
|
||||
|
||||
from ._terminations import (
|
||||
ExternalTermination,
|
||||
FunctionalTermination,
|
||||
FunctionCallTermination,
|
||||
HandoffTermination,
|
||||
MaxMessageTermination,
|
||||
SourceMatchTermination,
|
||||
StopMessageTermination,
|
||||
TextMentionTermination,
|
||||
TextMessageTermination,
|
||||
TimeoutTermination,
|
||||
TokenUsageTermination,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MaxMessageTermination",
|
||||
"TextMentionTermination",
|
||||
"StopMessageTermination",
|
||||
"TokenUsageTermination",
|
||||
"HandoffTermination",
|
||||
"TimeoutTermination",
|
||||
"ExternalTermination",
|
||||
"SourceMatchTermination",
|
||||
"TextMessageTermination",
|
||||
"FunctionCallTermination",
|
||||
"FunctionalTermination",
|
||||
]
|
||||
614
agent_dhal/agentdhal_agentchat/conditions/_terminations.py
Normal file
614
agent_dhal/agentdhal_agentchat/conditions/_terminations.py
Normal file
@@ -0,0 +1,614 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Awaitable, Callable, List, Sequence
|
||||
|
||||
from agentdhal_core import Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import TerminatedException, TerminationCondition
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
)
|
||||
|
||||
|
||||
class StopMessageTerminationConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
|
||||
"""Terminate the conversation if a StopMessage is received."""
|
||||
|
||||
component_config_schema = StopMessageTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.StopMessageTermination"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, StopMessage):
|
||||
self._terminated = True
|
||||
return StopMessage(content="Stop message received", source="StopMessageTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> StopMessageTerminationConfig:
|
||||
return StopMessageTerminationConfig()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StopMessageTerminationConfig) -> Self:
|
||||
return cls()
|
||||
|
||||
|
||||
class MaxMessageTerminationConfig(BaseModel):
|
||||
max_messages: int
|
||||
include_agent_event: bool = False
|
||||
|
||||
|
||||
class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]):
|
||||
"""Terminate the conversation after a maximum number of messages have been exchanged.
|
||||
|
||||
Args:
|
||||
max_messages: The maximum number of messages allowed in the conversation.
|
||||
include_agent_event: If True, include :class:`~agentdhal_agentchat.messages.BaseAgentEvent` in the message count.
|
||||
Otherwise, only include :class:`~agentdhal_agentchat.messages.BaseChatMessage`. Defaults to False.
|
||||
"""
|
||||
|
||||
component_config_schema = MaxMessageTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.MaxMessageTermination"
|
||||
|
||||
def __init__(self, max_messages: int, include_agent_event: bool = False) -> None:
|
||||
self._max_messages = max_messages
|
||||
self._message_count = 0
|
||||
self._include_agent_event = include_agent_event
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._message_count >= self._max_messages
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)])
|
||||
if self._message_count >= self._max_messages:
|
||||
return StopMessage(
|
||||
content=f"Maximum number of messages {self._max_messages} reached, current message count: {self._message_count}",
|
||||
source="MaxMessageTermination",
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._message_count = 0
|
||||
|
||||
def _to_config(self) -> MaxMessageTerminationConfig:
|
||||
return MaxMessageTerminationConfig(
|
||||
max_messages=self._max_messages, include_agent_event=self._include_agent_event
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MaxMessageTerminationConfig) -> Self:
|
||||
return cls(max_messages=config.max_messages, include_agent_event=config.include_agent_event)
|
||||
|
||||
|
||||
class TextMentionTerminationConfig(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class TextMentionTermination(TerminationCondition, Component[TextMentionTerminationConfig]):
|
||||
"""Terminate the conversation if a specific text is mentioned.
|
||||
|
||||
|
||||
Args:
|
||||
text: The text to look for in the messages.
|
||||
sources: Check only messages of the specified agents for the text to look for.
|
||||
"""
|
||||
|
||||
component_config_schema = TextMentionTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.TextMentionTermination"
|
||||
|
||||
def __init__(self, text: str, sources: Sequence[str] | None = None) -> None:
|
||||
self._termination_text = text
|
||||
self._terminated = False
|
||||
self._sources = sources
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if self._sources is not None and message.source not in self._sources:
|
||||
continue
|
||||
|
||||
content = message.to_text()
|
||||
if self._termination_text in content:
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Text '{self._termination_text}' mentioned", source="TextMentionTermination"
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> TextMentionTerminationConfig:
|
||||
return TextMentionTerminationConfig(text=self._termination_text)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TextMentionTerminationConfig) -> Self:
|
||||
return cls(text=config.text)
|
||||
|
||||
|
||||
class FunctionalTermination(TerminationCondition):
|
||||
"""Terminate the conversation if an functional expression is met.
|
||||
|
||||
Args:
|
||||
func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], bool] | Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[bool]]): A function that takes a sequence of messages
|
||||
and returns True if the termination condition is met, False otherwise.
|
||||
The function can be a callable or an async callable.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from typing import Sequence
|
||||
|
||||
from agentdhal_agentchat.conditions import FunctionalTermination
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
def expression(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> bool:
|
||||
# Check if the last message is a stop message
|
||||
return isinstance(messages[-1], StopMessage)
|
||||
|
||||
|
||||
termination = FunctionalTermination(expression)
|
||||
|
||||
|
||||
async def run() -> None:
|
||||
messages = [
|
||||
StopMessage(source="agent1", content="Stop"),
|
||||
]
|
||||
result = await termination(messages)
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
StopMessage(source="FunctionalTermination", content="Functional termination condition met")
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], bool]
|
||||
| Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[bool]],
|
||||
) -> None:
|
||||
self._func = func
|
||||
self._terminated = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if asyncio.iscoroutinefunction(self._func):
|
||||
result = await self._func(messages)
|
||||
else:
|
||||
result = self._func(messages)
|
||||
if result is True:
|
||||
self._terminated = True
|
||||
return StopMessage(content="Functional termination condition met", source="FunctionalTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
|
||||
class TokenUsageTerminationConfig(BaseModel):
|
||||
max_total_token: int | None
|
||||
max_prompt_token: int | None
|
||||
max_completion_token: int | None
|
||||
|
||||
|
||||
class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminationConfig]):
|
||||
"""Terminate the conversation if a token usage limit is reached.
|
||||
|
||||
Args:
|
||||
max_total_token: The maximum total number of tokens allowed in the conversation.
|
||||
max_prompt_token: The maximum number of prompt tokens allowed in the conversation.
|
||||
max_completion_token: The maximum number of completion tokens allowed in the conversation.
|
||||
|
||||
Raises:
|
||||
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
|
||||
"""
|
||||
|
||||
component_config_schema = TokenUsageTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.TokenUsageTermination"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_total_token: int | None = None,
|
||||
max_prompt_token: int | None = None,
|
||||
max_completion_token: int | None = None,
|
||||
) -> None:
|
||||
if max_total_token is None and max_prompt_token is None and max_completion_token is None:
|
||||
raise ValueError(
|
||||
"At least one of max_total_token, max_prompt_token, or max_completion_token must be provided"
|
||||
)
|
||||
self._max_total_token = max_total_token
|
||||
self._max_prompt_token = max_prompt_token
|
||||
self._max_completion_token = max_completion_token
|
||||
self._total_token_count = 0
|
||||
self._prompt_token_count = 0
|
||||
self._completion_token_count = 0
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return (
|
||||
(self._max_total_token is not None and self._total_token_count >= self._max_total_token)
|
||||
or (self._max_prompt_token is not None and self._prompt_token_count >= self._max_prompt_token)
|
||||
or (self._max_completion_token is not None and self._completion_token_count >= self._max_completion_token)
|
||||
)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if message.models_usage is not None:
|
||||
self._prompt_token_count += message.models_usage.prompt_tokens
|
||||
self._completion_token_count += message.models_usage.completion_tokens
|
||||
self._total_token_count += message.models_usage.prompt_tokens + message.models_usage.completion_tokens
|
||||
if self.terminated:
|
||||
content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}."
|
||||
return StopMessage(content=content, source="TokenUsageTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._total_token_count = 0
|
||||
self._prompt_token_count = 0
|
||||
self._completion_token_count = 0
|
||||
|
||||
def _to_config(self) -> TokenUsageTerminationConfig:
|
||||
return TokenUsageTerminationConfig(
|
||||
max_total_token=self._max_total_token,
|
||||
max_prompt_token=self._max_prompt_token,
|
||||
max_completion_token=self._max_completion_token,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TokenUsageTerminationConfig) -> Self:
|
||||
return cls(
|
||||
max_total_token=config.max_total_token,
|
||||
max_prompt_token=config.max_prompt_token,
|
||||
max_completion_token=config.max_completion_token,
|
||||
)
|
||||
|
||||
|
||||
class HandoffTerminationConfig(BaseModel):
|
||||
target: str
|
||||
|
||||
|
||||
class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfig]):
|
||||
"""Terminate the conversation if a :class:`~agentdhal_agentchat.messages.HandoffMessage`
|
||||
with the given target is received.
|
||||
|
||||
Args:
|
||||
target (str): The target of the handoff message.
|
||||
"""
|
||||
|
||||
component_config_schema = HandoffTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.HandoffTermination"
|
||||
|
||||
def __init__(self, target: str) -> None:
|
||||
self._terminated = False
|
||||
self._target = target
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, HandoffMessage) and message.target == self._target:
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Handoff to {self._target} from {message.source} detected.", source="HandoffTermination"
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> HandoffTerminationConfig:
|
||||
return HandoffTerminationConfig(target=self._target)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HandoffTerminationConfig) -> Self:
|
||||
return cls(target=config.target)
|
||||
|
||||
|
||||
class TimeoutTerminationConfig(BaseModel):
|
||||
timeout_seconds: float
|
||||
|
||||
|
||||
class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfig]):
|
||||
"""Terminate the conversation after a specified duration has passed.
|
||||
|
||||
Args:
|
||||
timeout_seconds: The maximum duration in seconds before terminating the conversation.
|
||||
"""
|
||||
|
||||
component_config_schema = TimeoutTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.TimeoutTermination"
|
||||
|
||||
def __init__(self, timeout_seconds: float) -> None:
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._start_time = time.monotonic()
|
||||
self._terminated = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
|
||||
if (time.monotonic() - self._start_time) >= self._timeout_seconds:
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Timeout of {self._timeout_seconds} seconds reached", source="TimeoutTermination"
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._start_time = time.monotonic()
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> TimeoutTerminationConfig:
|
||||
return TimeoutTerminationConfig(timeout_seconds=self._timeout_seconds)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TimeoutTerminationConfig) -> Self:
|
||||
return cls(timeout_seconds=config.timeout_seconds)
|
||||
|
||||
|
||||
class ExternalTerminationConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class ExternalTermination(TerminationCondition, Component[ExternalTerminationConfig]):
|
||||
"""A termination condition that is externally controlled
|
||||
by calling the :meth:`set` method.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_agentchat.conditions import ExternalTermination
|
||||
|
||||
termination = ExternalTermination()
|
||||
|
||||
# Run the team in an asyncio task.
|
||||
...
|
||||
|
||||
# Set the termination condition externally
|
||||
termination.set()
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = ExternalTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.ExternalTermination"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._terminated = False
|
||||
self._setted = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
def set(self) -> None:
|
||||
"""Set the termination condition to terminated."""
|
||||
self._setted = True
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if self._setted:
|
||||
self._terminated = True
|
||||
return StopMessage(content="External termination requested", source="ExternalTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
self._setted = False
|
||||
|
||||
def _to_config(self) -> ExternalTerminationConfig:
|
||||
return ExternalTerminationConfig()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ExternalTerminationConfig) -> Self:
|
||||
return cls()
|
||||
|
||||
|
||||
class SourceMatchTerminationConfig(BaseModel):
|
||||
sources: List[str]
|
||||
|
||||
|
||||
class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminationConfig]):
|
||||
"""Terminate the conversation after a specific source responds.
|
||||
|
||||
Args:
|
||||
sources (List[str]): List of source names to terminate the conversation.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached.
|
||||
"""
|
||||
|
||||
component_config_schema = SourceMatchTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.SourceMatchTermination"
|
||||
|
||||
def __init__(self, sources: List[str]) -> None:
|
||||
self._sources = sources
|
||||
self._terminated = False
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
if not messages:
|
||||
return None
|
||||
for message in messages:
|
||||
if message.source in self._sources:
|
||||
self._terminated = True
|
||||
return StopMessage(content=f"'{message.source}' answered", source="SourceMatchTermination")
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> SourceMatchTerminationConfig:
|
||||
return SourceMatchTerminationConfig(sources=self._sources)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SourceMatchTerminationConfig) -> Self:
|
||||
return cls(sources=config.sources)
|
||||
|
||||
|
||||
class TextMessageTerminationConfig(BaseModel):
|
||||
"""Configuration for the TextMessageTermination termination condition."""
|
||||
|
||||
source: str | None = None
|
||||
"""The source of the text message to terminate the conversation."""
|
||||
|
||||
|
||||
class TextMessageTermination(TerminationCondition, Component[TextMessageTerminationConfig]):
|
||||
"""Terminate the conversation if a :class:`~agentdhal_agentchat.messages.TextMessage` is received.
|
||||
|
||||
This termination condition checks for TextMessage instances in the message sequence. When a TextMessage is found,
|
||||
it terminates the conversation if either:
|
||||
- No source was specified (terminates on any TextMessage)
|
||||
- The message source matches the specified source
|
||||
|
||||
Args:
|
||||
source (str | None, optional): The source name to match against incoming messages. If None, matches any source.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
component_config_schema = TextMessageTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.TextMessageTermination"
|
||||
|
||||
def __init__(self, source: str | None = None) -> None:
|
||||
self._terminated = False
|
||||
self._source = source
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, TextMessage) and (self._source is None or message.source == self._source):
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Text message received from '{message.source}'", source="TextMessageTermination"
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> TextMessageTerminationConfig:
|
||||
return TextMessageTerminationConfig(source=self._source)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TextMessageTerminationConfig) -> Self:
|
||||
return cls(source=config.source)
|
||||
|
||||
|
||||
class FunctionCallTerminationConfig(BaseModel):
|
||||
"""Configuration for the :class:`FunctionCallTermination` termination condition."""
|
||||
|
||||
function_name: str
|
||||
|
||||
|
||||
class FunctionCallTermination(TerminationCondition, Component[FunctionCallTerminationConfig]):
|
||||
"""Terminate the conversation if a :class:`~agentdhal_core.models.FunctionExecutionResult`
|
||||
with a specific name was received.
|
||||
|
||||
Args:
|
||||
function_name (str): The name of the function to look for in the messages.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached.
|
||||
"""
|
||||
|
||||
component_config_schema = FunctionCallTerminationConfig
|
||||
component_provider_override = "agentdhal_agentchat.conditions.FunctionCallTermination"
|
||||
"""The schema for the component configuration."""
|
||||
|
||||
def __init__(self, function_name: str) -> None:
|
||||
self._terminated = False
|
||||
self._function_name = function_name
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return self._terminated
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, ToolCallExecutionEvent):
|
||||
for execution in message.content:
|
||||
if execution.name == self._function_name:
|
||||
self._terminated = True
|
||||
return StopMessage(
|
||||
content=f"Function '{self._function_name}' was executed.",
|
||||
source="FunctionCallTermination",
|
||||
)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._terminated = False
|
||||
|
||||
def _to_config(self) -> FunctionCallTerminationConfig:
|
||||
return FunctionCallTerminationConfig(
|
||||
function_name=self._function_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: FunctionCallTerminationConfig) -> Self:
|
||||
return cls(
|
||||
function_name=config.function_name,
|
||||
)
|
||||
693
agent_dhal/agentdhal_agentchat/messages.py
Normal file
693
agent_dhal/agentdhal_agentchat/messages.py
Normal file
@@ -0,0 +1,693 @@
|
||||
"""
|
||||
This module defines various message types used for agent-to-agent communication.
|
||||
Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
|
||||
class and includes specific fields relevant to the type of message being sent.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
|
||||
|
||||
from agentdhal_core import Component, ComponentBase, FunctionCall, Image
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeResult
|
||||
from agentdhal_core.memory import MemoryContent
|
||||
from agentdhal_core.models import (
|
||||
FunctionExecutionResult,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
UserMessage,
|
||||
)
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
|
||||
class BaseMessage(BaseModel, ABC):
|
||||
"""Abstract base class for all message types in AgentChat.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you want to create a new message type, do not inherit from this class.
|
||||
Instead, inherit from :class:`BaseChatMessage` or :class:`BaseAgentEvent`
|
||||
to clarify the purpose of the message type.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_text(self) -> str:
|
||||
"""Convert the message content to a string-only representation
|
||||
that can be rendered in the console and inspected by the user or conditions.
|
||||
This is not used for creating text-only content for models.
|
||||
For :class:`BaseChatMessage` types, use :meth:`to_model_text` instead."""
|
||||
...
|
||||
|
||||
def dump(self) -> Mapping[str, Any]:
|
||||
"""Convert the message to a JSON-serializable dictionary.
|
||||
|
||||
The default implementation uses the Pydantic model's
|
||||
:meth:`model_dump` method to convert the message to a dictionary.
|
||||
Datetime objects are automatically converted to ISO format strings
|
||||
to ensure JSON serialization compatibility.
|
||||
Override this method if you want to customize the serialization
|
||||
process or add additional fields to the output.
|
||||
"""
|
||||
return self.model_dump(mode="json")
|
||||
|
||||
@classmethod
|
||||
def load(cls, data: Mapping[str, Any]) -> Self:
|
||||
"""Create a message from a dictionary of JSON-serializable data.
|
||||
|
||||
The default implementation uses the Pydantic model's
|
||||
:meth:`model_validate` method to create the message from the data.
|
||||
Override this method if you want to customize the deserialization
|
||||
process or add additional fields to the input data."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
class BaseChatMessage(BaseMessage, ABC):
|
||||
"""Abstract base class for chat messages.
|
||||
|
||||
.. note::
|
||||
|
||||
If you want to create a new message type that is used for agent-to-agent
|
||||
communication, inherit from this class, or simply use
|
||||
:class:`StructuredMessage` if your content type is a subclass of
|
||||
Pydantic BaseModel.
|
||||
|
||||
This class is used for messages that are sent between agents in a chat
|
||||
conversation. Agents are expected to process the content of the
|
||||
message using models and return a response as another :class:`BaseChatMessage`.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""Unique identifier for this message."""
|
||||
|
||||
source: str
|
||||
"""The name of the agent that sent this message."""
|
||||
|
||||
models_usage: RequestUsage | None = None
|
||||
"""The model client usage incurred when producing this message."""
|
||||
|
||||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The time when the message was created."""
|
||||
|
||||
@abstractmethod
|
||||
def to_model_text(self) -> str:
|
||||
"""Convert the content of the message to text-only representation.
|
||||
This is used for creating text-only content for models.
|
||||
|
||||
This is not used for rendering the message in console. For that, use
|
||||
:meth:`~BaseMessage.to_text`.
|
||||
|
||||
The difference between this and :meth:`to_model_message` is that this
|
||||
is used to construct parts of the a message for the model client,
|
||||
while :meth:`to_model_message` is used to create a complete message
|
||||
for the model client.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_model_message(self) -> UserMessage:
|
||||
"""Convert the message content to a :class:`~agentdhal_core.models.UserMessage`
|
||||
for use with model client, e.g., :class:`~agentdhal_core.models.ChatCompletionClient`.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BaseTextChatMessage(BaseChatMessage, ABC):
|
||||
"""Base class for all text-only :class:`BaseChatMessage` types.
|
||||
It has implementations for :meth:`to_text`, :meth:`to_model_text`,
|
||||
and :meth:`to_model_message` methods.
|
||||
|
||||
Inherit from this class if your message content type is a string.
|
||||
"""
|
||||
|
||||
content: str
|
||||
"""The content of the message."""
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
return UserMessage(content=self.content, source=self.source)
|
||||
|
||||
|
||||
class BaseAgentEvent(BaseMessage, ABC):
|
||||
"""Base class for agent events.
|
||||
|
||||
.. note::
|
||||
|
||||
If you want to create a new message type for signaling observable events
|
||||
to user and application, inherit from this class.
|
||||
|
||||
Agent events are used to signal actions and thoughts produced by agents
|
||||
and teams to user and applications. They are not used for agent-to-agent
|
||||
communication and are not expected to be processed by other agents.
|
||||
|
||||
You should override the :meth:`to_text` method if you want to provide
|
||||
a custom rendering of the content.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""Unique identifier for this event."""
|
||||
|
||||
source: str
|
||||
"""The name of the agent that sent this message."""
|
||||
|
||||
models_usage: RequestUsage | None = None
|
||||
"""The model client usage incurred when producing this message."""
|
||||
|
||||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The time when the message was created."""
|
||||
|
||||
|
||||
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
|
||||
"""Type variable for structured content types."""
|
||||
|
||||
|
||||
class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
|
||||
"""A :class:`BaseChatMessage` type with an unspecified content type.
|
||||
|
||||
To create a new structured message type, specify the content type
|
||||
as a subclass of `Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from agentdhal_agentchat.messages import StructuredMessage
|
||||
|
||||
|
||||
class MyMessageContent(BaseModel):
|
||||
text: str
|
||||
number: int
|
||||
|
||||
|
||||
message = StructuredMessage[MyMessageContent](
|
||||
content=MyMessageContent(text="Hello", number=42),
|
||||
source="agent1",
|
||||
)
|
||||
|
||||
print(message.to_text()) # {"text": "Hello", "number": 42}
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from agentdhal_agentchat.messages import StructuredMessage
|
||||
|
||||
|
||||
class MyMessageContent(BaseModel):
|
||||
text: str
|
||||
number: int
|
||||
|
||||
|
||||
message = StructuredMessage[MyMessageContent](
|
||||
content=MyMessageContent(text="Hello", number=42),
|
||||
source="agent",
|
||||
format_string="Hello, {text} {number}!",
|
||||
)
|
||||
|
||||
print(message.to_text()) # Hello, agent 42!
|
||||
|
||||
"""
|
||||
|
||||
content: StructuredContentType
|
||||
"""The content of the message. Must be a subclass of
|
||||
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
|
||||
|
||||
format_string: Optional[str] = None
|
||||
"""(Experimental) An optional format string to render the content into a human-readable format.
|
||||
The format string can use the fields of the content model as placeholders.
|
||||
For example, if the content model has a field `name`, you can use
|
||||
`{name}` in the format string to include the value of that field.
|
||||
The format string is used in the :meth:`to_text` method to create a
|
||||
human-readable representation of the message.
|
||||
This setting is experimental and will change in the future.
|
||||
"""
|
||||
|
||||
@computed_field
|
||||
def type(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_text(self) -> str:
|
||||
if self.format_string is not None:
|
||||
return self.format_string.format(**self.content.model_dump())
|
||||
else:
|
||||
return self.content.model_dump_json()
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
if self.format_string is not None:
|
||||
return self.format_string.format(**self.content.model_dump())
|
||||
else:
|
||||
return self.content.model_dump_json()
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
return UserMessage(
|
||||
content=self.content.model_dump_json(),
|
||||
source=self.source,
|
||||
)
|
||||
|
||||
|
||||
class StructureMessageConfig(BaseModel):
|
||||
"""The declarative configuration for the structured output."""
|
||||
|
||||
json_schema: Dict[str, Any]
|
||||
format_string: Optional[str] = None
|
||||
content_model_name: str
|
||||
|
||||
|
||||
class StructuredMessageFactory(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]):
|
||||
""":meta private:
|
||||
|
||||
A component that creates structured chat messages from Pydantic models or JSON schemas.
|
||||
|
||||
This component helps you generate strongly-typed chat messages with content defined using a Pydantic model.
|
||||
It can be used in declarative workflows where message structure must be validated, formatted, and serialized.
|
||||
|
||||
You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration
|
||||
object (e.g., loaded from disk or a database).
|
||||
|
||||
### Example 1: Create from a Pydantic Model
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from agentdhal_agentchat.messages import StructuredMessageFactory
|
||||
|
||||
|
||||
class TestContent(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
|
||||
|
||||
format_string = "This is a string {field1} and this is an int {field2}"
|
||||
sm_component = StructuredMessageFactory(input_model=TestContent, format_string=format_string)
|
||||
|
||||
message = sm_component.StructuredMessage(
|
||||
source="test_agent", content=TestContent(field1="Hello", field2=42), format_string=format_string
|
||||
)
|
||||
|
||||
print(message.to_model_text()) # Output: This is a string Hello and this is an int 42
|
||||
|
||||
config = sm_component.dump_component()
|
||||
|
||||
s_m_dyn = StructuredMessageFactory.load_component(config)
|
||||
message = s_m_dyn.StructuredMessage(
|
||||
source="test_agent",
|
||||
content=s_m_dyn.ContentModel(field1="dyn agent", field2=43),
|
||||
format_string=s_m_dyn.format_string,
|
||||
)
|
||||
print(type(message)) # StructuredMessage[GeneratedModel]
|
||||
print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43
|
||||
|
||||
Attributes:
|
||||
component_config_schema (StructureMessageConfig): Defines the configuration structure for this component.
|
||||
component_provider_override (str): Path used to reference this component in external tooling.
|
||||
component_type (str): Identifier used for categorization (e.g., "structured_message").
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `json_schema` nor `input_model` is provided.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[str]): JSON schema to dynamically create a Pydantic model.
|
||||
input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure.
|
||||
format_string (Optional[str]): Optional string to render content into a human-readable format.
|
||||
content_model_name (Optional[str]): Optional name for the generated Pydantic model.
|
||||
"""
|
||||
|
||||
component_config_schema = StructureMessageConfig
|
||||
component_provider_override = "agentdhal_agentchat.messages.StructuredMessageFactory"
|
||||
component_type = "structured_message"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
input_model: Optional[Type[BaseModel]] = None,
|
||||
format_string: Optional[str] = None,
|
||||
content_model_name: Optional[str] = None,
|
||||
) -> None:
|
||||
self.format_string = format_string
|
||||
|
||||
if json_schema:
|
||||
self.ContentModel = schema_to_pydantic_model(
|
||||
json_schema, model_name=content_model_name or "GeneratedContentModel"
|
||||
)
|
||||
elif input_model:
|
||||
self.ContentModel = input_model
|
||||
else:
|
||||
raise ValueError("Either `json_schema` or `input_model` must be provided.")
|
||||
|
||||
self.StructuredMessage = StructuredMessage[self.ContentModel] # type: ignore[name-defined]
|
||||
|
||||
def _to_config(self) -> StructureMessageConfig:
|
||||
return StructureMessageConfig(
|
||||
json_schema=self.ContentModel.model_json_schema(),
|
||||
format_string=self.format_string,
|
||||
content_model_name=self.ContentModel.__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageFactory":
|
||||
return cls(
|
||||
json_schema=config.json_schema,
|
||||
format_string=config.format_string,
|
||||
content_model_name=config.content_model_name,
|
||||
)
|
||||
|
||||
|
||||
class TextMessage(BaseTextChatMessage):
|
||||
"""A text message with string-only content."""
|
||||
|
||||
type: Literal["TextMessage"] = "TextMessage"
|
||||
|
||||
|
||||
class MultiModalMessage(BaseChatMessage):
|
||||
"""A multimodal message."""
|
||||
|
||||
content: List[str | Image]
|
||||
"""The content of the message."""
|
||||
|
||||
type: Literal["MultiModalMessage"] = "MultiModalMessage"
|
||||
|
||||
def to_model_text(self, image_placeholder: str | None = "[image]") -> str:
|
||||
"""Convert the content of the message to a string-only representation.
|
||||
If an image is present, it will be replaced with the image placeholder
|
||||
by default, otherwise it will be a base64 string when set to None.
|
||||
"""
|
||||
text = ""
|
||||
for c in self.content:
|
||||
if isinstance(c, str):
|
||||
text += c
|
||||
elif isinstance(c, Image):
|
||||
if image_placeholder is not None:
|
||||
text += f" {image_placeholder}"
|
||||
else:
|
||||
text += f" {c.to_base64()}"
|
||||
return text
|
||||
|
||||
def to_text(self, iterm: bool = False) -> str:
|
||||
result: List[str] = []
|
||||
for c in self.content:
|
||||
if isinstance(c, str):
|
||||
result.append(c)
|
||||
else:
|
||||
if iterm:
|
||||
# iTerm2 image rendering protocol: https://iterm2.com/documentation-images.html
|
||||
image_data = c.to_base64()
|
||||
result.append(f"\033]1337;File=inline=1:{image_data}\a\n")
|
||||
else:
|
||||
result.append("<image>")
|
||||
return "\n".join(result)
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
return UserMessage(content=self.content, source=self.source)
|
||||
|
||||
|
||||
class StopMessage(BaseTextChatMessage):
|
||||
"""A message requesting stop of a conversation."""
|
||||
|
||||
type: Literal["StopMessage"] = "StopMessage"
|
||||
|
||||
|
||||
class HandoffMessage(BaseTextChatMessage):
|
||||
"""A message requesting handoff of a conversation to another agent."""
|
||||
|
||||
target: str
|
||||
"""The name of the target agent to handoff to."""
|
||||
|
||||
context: List[LLMMessage] = []
|
||||
"""The model context to be passed to the target agent."""
|
||||
|
||||
type: Literal["HandoffMessage"] = "HandoffMessage"
|
||||
|
||||
|
||||
class ToolCallSummaryMessage(BaseTextChatMessage):
|
||||
"""A message signaling the summary of tool call results."""
|
||||
|
||||
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
|
||||
|
||||
tool_calls: List[FunctionCall]
|
||||
"""The tool calls that were made."""
|
||||
|
||||
results: List[FunctionExecutionResult]
|
||||
"""The results of the tool calls."""
|
||||
|
||||
|
||||
class ToolCallRequestEvent(BaseAgentEvent):
|
||||
"""An event signaling a request to use tools."""
|
||||
|
||||
content: List[FunctionCall]
|
||||
"""The tool calls."""
|
||||
|
||||
type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class CodeGenerationEvent(BaseAgentEvent):
|
||||
"""An event signaling code generation event."""
|
||||
|
||||
retry_attempt: int
|
||||
"Retry number, 0 means first generation"
|
||||
|
||||
content: str
|
||||
"The complete content as string."
|
||||
|
||||
code_blocks: List[CodeBlock]
|
||||
"List of code blocks present in content"
|
||||
|
||||
type: Literal["CodeGenerationEvent"] = "CodeGenerationEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class CodeExecutionEvent(BaseAgentEvent):
|
||||
"""An event signaling code execution event."""
|
||||
|
||||
retry_attempt: int
|
||||
"Retry number, 0 means first execution"
|
||||
|
||||
result: CodeResult
|
||||
"Code Execution Result"
|
||||
|
||||
type: Literal["CodeExecutionEvent"] = "CodeExecutionEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.result.output
|
||||
|
||||
|
||||
class ToolCallExecutionEvent(BaseAgentEvent):
|
||||
"""An event signaling the execution of tool calls."""
|
||||
|
||||
content: List[FunctionExecutionResult]
|
||||
"""The tool call results."""
|
||||
|
||||
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class UserInputRequestedEvent(BaseAgentEvent):
|
||||
"""An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback."""
|
||||
|
||||
request_id: str
|
||||
"""Identifier for the user input request."""
|
||||
|
||||
content: Literal[""] = ""
|
||||
"""Empty content for compat with consumers expecting a content field."""
|
||||
|
||||
type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class MemoryQueryEvent(BaseAgentEvent):
|
||||
"""An event signaling the results of memory queries."""
|
||||
|
||||
content: List[MemoryContent]
|
||||
"""The memory query results."""
|
||||
|
||||
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class ModelClientStreamingChunkEvent(BaseAgentEvent):
|
||||
"""An event signaling a text output chunk from a model client in streaming mode."""
|
||||
|
||||
content: str
|
||||
"""A string chunk from the model client."""
|
||||
|
||||
full_message_id: str | None = None
|
||||
"""Optional reference to the complete message that may come after the chunks.
|
||||
This allows consumers of the stream to correlate chunks with the eventual completed message."""
|
||||
|
||||
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class ThoughtEvent(BaseAgentEvent):
|
||||
"""An event signaling the thought process of a model.
|
||||
It is used to communicate the reasoning tokens generated by a reasoning model,
|
||||
or the extra text content generated by a function call."""
|
||||
|
||||
content: str
|
||||
"""The thought process of the model."""
|
||||
|
||||
type: Literal["ThoughtEvent"] = "ThoughtEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class SelectSpeakerEvent(BaseAgentEvent):
|
||||
"""An event signaling the selection of speakers for a conversation."""
|
||||
|
||||
content: List[str]
|
||||
"""The names of the selected speakers."""
|
||||
|
||||
type: Literal["SelectSpeakerEvent"] = "SelectSpeakerEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class SelectorEvent(BaseAgentEvent):
|
||||
"""An event emitted from the `SelectorGroupChat`."""
|
||||
|
||||
content: str
|
||||
"""The content of the event."""
|
||||
|
||||
type: Literal["SelectorEvent"] = "SelectorEvent"
|
||||
|
||||
def to_text(self) -> str:
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class MessageFactory:
|
||||
""":meta private:
|
||||
|
||||
A factory for creating messages from JSON-serializable dictionaries.
|
||||
|
||||
This is useful for deserializing messages from JSON data.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._message_types: Dict[str, type[BaseAgentEvent | BaseChatMessage]] = {}
|
||||
# Register all message types.
|
||||
self._message_types[TextMessage.__name__] = TextMessage
|
||||
self._message_types[MultiModalMessage.__name__] = MultiModalMessage
|
||||
self._message_types[StopMessage.__name__] = StopMessage
|
||||
self._message_types[ToolCallSummaryMessage.__name__] = ToolCallSummaryMessage
|
||||
self._message_types[HandoffMessage.__name__] = HandoffMessage
|
||||
self._message_types[ToolCallRequestEvent.__name__] = ToolCallRequestEvent
|
||||
self._message_types[ToolCallExecutionEvent.__name__] = ToolCallExecutionEvent
|
||||
self._message_types[MemoryQueryEvent.__name__] = MemoryQueryEvent
|
||||
self._message_types[UserInputRequestedEvent.__name__] = UserInputRequestedEvent
|
||||
self._message_types[ModelClientStreamingChunkEvent.__name__] = ModelClientStreamingChunkEvent
|
||||
self._message_types[ThoughtEvent.__name__] = ThoughtEvent
|
||||
self._message_types[SelectSpeakerEvent.__name__] = SelectSpeakerEvent
|
||||
self._message_types[CodeGenerationEvent.__name__] = CodeGenerationEvent
|
||||
self._message_types[CodeExecutionEvent.__name__] = CodeExecutionEvent
|
||||
|
||||
def is_registered(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> bool:
|
||||
"""Check if a message type is registered with the factory."""
|
||||
# Get the class name of the message type.
|
||||
class_name = message_type.__name__
|
||||
# Check if the class name is already registered.
|
||||
return class_name in self._message_types
|
||||
|
||||
def register(self, message_type: type[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
"""Register a new message type with the factory."""
|
||||
if self.is_registered(message_type):
|
||||
raise ValueError(f"Message type {message_type} is already registered.")
|
||||
if not issubclass(message_type, BaseChatMessage) and not issubclass(message_type, BaseAgentEvent):
|
||||
raise ValueError(f"Message type {message_type} must be a subclass of BaseChatMessage or BaseAgentEvent.")
|
||||
# Get the class name of the
|
||||
class_name = message_type.__name__
|
||||
# Check if the class name is already registered.
|
||||
# Register the message type.
|
||||
self._message_types[class_name] = message_type
|
||||
|
||||
def create(self, data: Mapping[str, Any]) -> BaseAgentEvent | BaseChatMessage:
|
||||
"""Create a message from a dictionary of JSON-serializable data."""
|
||||
# Get the type of the message from the dictionary.
|
||||
message_type = data.get("type")
|
||||
if message_type is None:
|
||||
raise ValueError("Field 'type' is required in the message data to recover the message type.")
|
||||
if message_type not in self._message_types:
|
||||
raise ValueError(f"Unknown message type: {message_type}")
|
||||
if not isinstance(message_type, str):
|
||||
raise ValueError(f"Message type must be a string, got {type(message_type)}")
|
||||
|
||||
# Get the class for the message type.
|
||||
message_class = self._message_types[message_type]
|
||||
|
||||
# Create an instance of the message class.
|
||||
assert issubclass(message_class, BaseChatMessage) or issubclass(message_class, BaseAgentEvent)
|
||||
return message_class.load(data)
|
||||
|
||||
|
||||
ChatMessage = Annotated[
|
||||
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
"""The union type of all built-in concrete subclasses of :class:`BaseChatMessage`.
|
||||
It does not include :class:`StructuredMessage` types."""
|
||||
|
||||
AgentEvent = Annotated[
|
||||
ToolCallRequestEvent
|
||||
| ToolCallExecutionEvent
|
||||
| MemoryQueryEvent
|
||||
| UserInputRequestedEvent
|
||||
| ModelClientStreamingChunkEvent
|
||||
| ThoughtEvent
|
||||
| SelectSpeakerEvent
|
||||
| CodeGenerationEvent
|
||||
| CodeExecutionEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
"""The union type of all built-in concrete subclasses of :class:`BaseAgentEvent`."""
|
||||
|
||||
__all__ = [
|
||||
"AgentEvent",
|
||||
"BaseMessage",
|
||||
"ChatMessage",
|
||||
"BaseChatMessage",
|
||||
"BaseAgentEvent",
|
||||
"BaseTextChatMessage",
|
||||
"StructuredContentType",
|
||||
"StructuredMessage",
|
||||
"StructuredMessageFactory",
|
||||
"HandoffMessage",
|
||||
"MultiModalMessage",
|
||||
"StopMessage",
|
||||
"TextMessage",
|
||||
"ToolCallExecutionEvent",
|
||||
"ToolCallRequestEvent",
|
||||
"ToolCallSummaryMessage",
|
||||
"MemoryQueryEvent",
|
||||
"UserInputRequestedEvent",
|
||||
"ModelClientStreamingChunkEvent",
|
||||
"ThoughtEvent",
|
||||
"SelectSpeakerEvent",
|
||||
"MessageFactory",
|
||||
"CodeGenerationEvent",
|
||||
"CodeExecutionEvent",
|
||||
]
|
||||
0
agent_dhal/agentdhal_agentchat/py.typed
Normal file
0
agent_dhal/agentdhal_agentchat/py.typed
Normal file
27
agent_dhal/agentdhal_agentchat/state/__init__.py
Normal file
27
agent_dhal/agentdhal_agentchat/state/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""State management for agents, teams and termination conditions."""
|
||||
|
||||
from ._states import (
|
||||
AssistantAgentState,
|
||||
BaseGroupChatManagerState,
|
||||
BaseState,
|
||||
ChatAgentContainerState,
|
||||
MagenticOneOrchestratorState,
|
||||
RoundRobinManagerState,
|
||||
SelectorManagerState,
|
||||
SocietyOfMindAgentState,
|
||||
SwarmManagerState,
|
||||
TeamState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseState",
|
||||
"AssistantAgentState",
|
||||
"BaseGroupChatManagerState",
|
||||
"ChatAgentContainerState",
|
||||
"RoundRobinManagerState",
|
||||
"SelectorManagerState",
|
||||
"SwarmManagerState",
|
||||
"MagenticOneOrchestratorState",
|
||||
"TeamState",
|
||||
"SocietyOfMindAgentState",
|
||||
]
|
||||
79
agent_dhal/agentdhal_agentchat/state/_states.py
Normal file
79
agent_dhal/agentdhal_agentchat/state/_states.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseState(BaseModel):
|
||||
"""Base class for all saveable state"""
|
||||
|
||||
type: str = Field(default="BaseState")
|
||||
version: str = Field(default="1.0.0")
|
||||
|
||||
|
||||
class HalState(BaseState):
|
||||
"""State for an assistant agent."""
|
||||
|
||||
llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
|
||||
type: str = Field(default="AssistantAgentState")
|
||||
|
||||
|
||||
class TeamState(BaseState):
|
||||
"""State for a team of agents."""
|
||||
|
||||
agent_states: Mapping[str, Any] = Field(default_factory=dict)
|
||||
type: str = Field(default="TeamState")
|
||||
|
||||
|
||||
class BaseGroupChatManagerState(BaseState):
|
||||
"""Base state for all group chat managers."""
|
||||
|
||||
message_thread: List[Mapping[str, Any]] = Field(default_factory=list)
|
||||
current_turn: int = Field(default=0)
|
||||
type: str = Field(default="BaseGroupChatManagerState")
|
||||
|
||||
|
||||
class ChatAgentContainerState(BaseState):
|
||||
"""State for a container of chat agents."""
|
||||
|
||||
agent_state: Mapping[str, Any] = Field(default_factory=dict)
|
||||
message_buffer: List[Mapping[str, Any]] = Field(default_factory=list)
|
||||
type: str = Field(default="ChatAgentContainerState")
|
||||
|
||||
|
||||
class RoundRobinManagerState(BaseGroupChatManagerState):
|
||||
"""State for :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` manager."""
|
||||
|
||||
next_speaker_index: int = Field(default=0)
|
||||
type: str = Field(default="RoundRobinManagerState")
|
||||
|
||||
|
||||
class SelectorManagerState(BaseGroupChatManagerState):
|
||||
"""State for :class:`~agentdhal_agentchat.teams.SelectorGroupChat` manager."""
|
||||
|
||||
previous_speaker: Optional[str] = Field(default=None)
|
||||
type: str = Field(default="SelectorManagerState")
|
||||
|
||||
|
||||
class SwarmManagerState(BaseGroupChatManagerState):
|
||||
"""State for :class:`~agentdhal_agentchat.teams.Swarm` manager."""
|
||||
|
||||
current_speaker: str = Field(default="")
|
||||
type: str = Field(default="SwarmManagerState")
|
||||
|
||||
|
||||
class MagenticOneOrchestratorState(BaseGroupChatManagerState):
|
||||
"""State for :class:`~agentdhal_agentchat.teams.MagneticOneGroupChat` orchestrator."""
|
||||
|
||||
task: str = Field(default="")
|
||||
facts: str = Field(default="")
|
||||
plan: str = Field(default="")
|
||||
n_rounds: int = Field(default=0)
|
||||
n_stalls: int = Field(default=0)
|
||||
type: str = Field(default="MagenticOneOrchestratorState")
|
||||
|
||||
|
||||
class SocietyOfMindAgentState(BaseState):
|
||||
"""State for a Society of Mind agent."""
|
||||
|
||||
inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
|
||||
type: str = Field(default="SocietyOfMindAgentState")
|
||||
30
agent_dhal/agentdhal_agentchat/teams/__init__.py
Normal file
30
agent_dhal/agentdhal_agentchat/teams/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
This module provides implementation of various pre-defined multi-agent teams.
|
||||
Each team inherits from the BaseGroupChat class.
|
||||
"""
|
||||
|
||||
from ._group_chat._base_group_chat import BaseGroupChat
|
||||
from ._group_chat._graph import (
|
||||
DiGraph,
|
||||
DiGraphBuilder,
|
||||
DiGraphEdge,
|
||||
DiGraphNode,
|
||||
GraphFlow,
|
||||
)
|
||||
from ._group_chat._magentic_one import MagenticOneGroupChat
|
||||
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
|
||||
from ._group_chat._selector_group_chat import SelectorGroupChat
|
||||
from ._group_chat._swarm_group_chat import Swarm
|
||||
|
||||
__all__ = [
|
||||
"BaseGroupChat",
|
||||
"RoundRobinGroupChat",
|
||||
"SelectorGroupChat",
|
||||
"Swarm",
|
||||
"MagenticOneGroupChat",
|
||||
"DiGraphBuilder",
|
||||
"DiGraph",
|
||||
"DiGraphNode",
|
||||
"DiGraphEdge",
|
||||
"GraphFlow",
|
||||
]
|
||||
@@ -0,0 +1,834 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import (
|
||||
AgentId,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
ComponentBase,
|
||||
SingleThreadedAgentRuntime,
|
||||
TypeSubscription,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import (
|
||||
GroupChatPause,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
"""The base class for group chat teams.
|
||||
|
||||
In a group chat team, participants share context by publishing their messages
|
||||
to all other participants.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
|
||||
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
|
||||
|
||||
This base class provides the mapping between the agents of the AgentChat API
|
||||
and the agent runtime of the Core API, and handles high-level features like
|
||||
running, pausing, resuming, and resetting the team.
|
||||
"""
|
||||
|
||||
component_type = "team"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
participants: List[ChatAgent | Team],
|
||||
group_chat_manager_name: str,
|
||||
group_chat_manager_class: type[SequentialRoutedAgent],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self._description = description
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
if len(participants) != len(set(participant.name for participant in participants)):
|
||||
raise ValueError("The participant names must be unique.")
|
||||
self._participants = participants
|
||||
self._base_group_chat_manager_class = group_chat_manager_class
|
||||
self._termination_condition = termination_condition
|
||||
self._max_turns = max_turns
|
||||
self._message_factory = MessageFactory()
|
||||
if custom_message_types is not None:
|
||||
for message_type in custom_message_types:
|
||||
self._message_factory.register(message_type)
|
||||
|
||||
for agent in participants:
|
||||
if isinstance(agent, ChatAgent):
|
||||
for message_type in agent.produced_message_types:
|
||||
try:
|
||||
is_registered = self._message_factory.is_registered(message_type) # type: ignore[reportUnknownArgumentType]
|
||||
if issubclass(message_type, StructuredMessage) and not is_registered:
|
||||
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
|
||||
except TypeError:
|
||||
# Not a class or not a valid subclassable type (skip)
|
||||
pass
|
||||
|
||||
# The team ID is a UUID that is used to identify the team and its participants
|
||||
# in the agent runtime. It is used to create unique topic types for each participant.
|
||||
# Currently, team ID is binded to an object instance of the group chat class.
|
||||
# So if you create two instances of group chat, there will be two teams with different IDs.
|
||||
self._team_id = str(uuid.uuid4())
|
||||
|
||||
# Constants for the group chat team.
|
||||
# The names are used to identify the agents within the team.
|
||||
# The names may not be unique across different teams.
|
||||
self._group_chat_manager_name = group_chat_manager_name
|
||||
self._participant_names: List[str] = [participant.name for participant in participants]
|
||||
self._participant_descriptions: List[str] = [participant.description for participant in participants]
|
||||
# The group chat topic type is used for broadcast communication among all participants and the group chat manager.
|
||||
self._group_topic_type = f"group_topic_{self._team_id}"
|
||||
# The group chat manager topic type is used for direct communication with the group chat manager.
|
||||
self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}"
|
||||
# The participant topic types are used for direct communication with each participant.
|
||||
self._participant_topic_types: List[str] = [
|
||||
f"{participant.name}_{self._team_id}" for participant in participants
|
||||
]
|
||||
# The output topic type is used for emitting streaming messages from the group chat.
|
||||
# The group chat manager will relay the messages to the output message queue.
|
||||
self._output_topic_type = f"output_topic_{self._team_id}"
|
||||
|
||||
# The queue for collecting the output messages.
|
||||
self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
|
||||
# Create a runtime for the team.
|
||||
if runtime is not None:
|
||||
self._runtime = runtime
|
||||
self._embedded_runtime = False
|
||||
else:
|
||||
# Use a embedded single-threaded runtime for the group chat.
|
||||
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
|
||||
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
self._embedded_runtime = True
|
||||
|
||||
# Flag to track if the group chat has been initialized.
|
||||
self._initialized = False
|
||||
|
||||
# Flag to track if the group chat is running.
|
||||
self._is_running = False
|
||||
|
||||
# Flag to track if the team events should be emitted.
|
||||
self._emit_team_events = emit_team_events
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the group chat team."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""A description of the group chat team."""
|
||||
return self._description
|
||||
|
||||
@abstractmethod
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], SequentialRoutedAgent]: ...
|
||||
|
||||
def _create_participant_factory(
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
output_topic_type: str,
|
||||
agent: ChatAgent | Team,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], ChatAgentContainer]:
|
||||
def _factory() -> ChatAgentContainer:
|
||||
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent, message_factory)
|
||||
return container
|
||||
|
||||
return _factory
|
||||
|
||||
async def _init(self, runtime: AgentRuntime) -> None:
|
||||
# Constants for the group chat manager.
|
||||
group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type)
|
||||
|
||||
# Register participants.
|
||||
# Use the participant topic type as the agent type.
|
||||
for participant, agent_type in zip(self._participants, self._participant_topic_types, strict=True):
|
||||
# Register the participant factory.
|
||||
await ChatAgentContainer.register(
|
||||
runtime,
|
||||
type=agent_type,
|
||||
factory=self._create_participant_factory(
|
||||
self._group_topic_type, self._output_topic_type, participant, self._message_factory
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the participant.
|
||||
# The participant should be able to receive messages from its own topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=agent_type, agent_type=agent_type))
|
||||
# The participant should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type))
|
||||
|
||||
# Register the group chat manager.
|
||||
await self._base_group_chat_manager_class.register(
|
||||
runtime,
|
||||
type=group_chat_manager_agent_type.type,
|
||||
factory=self._create_group_chat_manager_factory(
|
||||
name=self._group_chat_manager_name,
|
||||
group_topic_type=self._group_topic_type,
|
||||
output_topic_type=self._output_topic_type,
|
||||
participant_names=self._participant_names,
|
||||
participant_topic_types=self._participant_topic_types,
|
||||
participant_descriptions=self._participant_descriptions,
|
||||
output_message_queue=self._output_message_queue,
|
||||
termination_condition=self._termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
message_factory=self._message_factory,
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the group chat manager.
|
||||
# The group chat manager should be able to receive messages from the its own topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(
|
||||
topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type
|
||||
)
|
||||
)
|
||||
# The group chat manager should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
# The group chat manager will relay the messages from output topic to the output message queue.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._output_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
:meth:`run_stream` to run the team and then returns the final result.
|
||||
Once the team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead.
|
||||
|
||||
Returns:
|
||||
result: The result of the task as :class:`~agentdhal_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
result = await team.run(task="Count from 1 to 10, respond one at a time.")
|
||||
print(result)
|
||||
|
||||
# Run the team again without a task to continue the previous task.
|
||||
result = await team.run()
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create a task to run the team in the background.
|
||||
run_task = asyncio.create_task(
|
||||
team.run(
|
||||
task="Count from 1 to 10, respond one at a time.",
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for 1 second and then cancel the task.
|
||||
await asyncio.sleep(1)
|
||||
cancellation_token.cancel()
|
||||
|
||||
# This will raise a cancellation error.
|
||||
await run_task
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
result: TaskResult | None = None
|
||||
async for message in self.run_stream(
|
||||
task=task,
|
||||
cancellation_token=cancellation_token,
|
||||
output_task_messages=output_task_messages,
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
result = message
|
||||
if result is not None:
|
||||
return result
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
of the type :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream. Once the
|
||||
team is stopped, the termination condition is reset.
|
||||
|
||||
.. note::
|
||||
|
||||
If an agent produces :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent`,
|
||||
the message will be yielded in the stream but it will not be included in the
|
||||
:attr:`~agentdhal_agentchat.base.TaskResult.messages`.
|
||||
|
||||
Args:
|
||||
task (str | BaseChatMessage | Sequence[BaseChatMessage] | None): The task to run the team with. Can be a string, a single :class:`BaseChatMessage` , or a list of :class:`BaseChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
To gracefully stop the team, use :class:`~agentdhal_agentchat.conditions.ExternalTermination` instead.
|
||||
output_task_messages (bool): Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
|
||||
Returns:
|
||||
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~agentdhal_agentchat.messages.BaseAgentEvent`, :class:`~agentdhal_agentchat.messages.BaseChatMessage`, and the final result :class:`~agentdhal_agentchat.base.TaskResult` as the last item in the stream.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
# Run the team again without a task to continue the previous task.
|
||||
stream = team.run_stream()
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Example using the :class:`~agentdhal_core.CancellationToken` to cancel the task:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create a task to run the team in the background.
|
||||
run_task = asyncio.create_task(
|
||||
Console(
|
||||
team.run_stream(
|
||||
task="Count from 1 to 10, respond one at a time.",
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for 1 second and then cancel the task.
|
||||
await asyncio.sleep(1)
|
||||
cancellation_token.cancel()
|
||||
|
||||
# This will raise a cancellation error.
|
||||
await run_task
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
# Create the messages list if the task is a string or a chat message.
|
||||
messages: List[BaseChatMessage] | None = None
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
messages = [TextMessage(content=task, source="user")]
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
messages = [task]
|
||||
elif isinstance(task, list):
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
messages = []
|
||||
for msg in task:
|
||||
if not isinstance(msg, BaseChatMessage):
|
||||
raise ValueError("All messages in task list must be valid BaseChatMessage types")
|
||||
messages.append(msg)
|
||||
else:
|
||||
raise ValueError("Task must be a string, a BaseChatMessage, or a list of BaseChatMessage.")
|
||||
# Check if the messages types are registered with the message factory.
|
||||
if messages is not None:
|
||||
for msg in messages:
|
||||
if not self._message_factory.is_registered(msg.__class__):
|
||||
raise ValueError(
|
||||
f"Message type {msg.__class__} is not registered with the message factory. "
|
||||
"Please register it with the message factory by adding it to the "
|
||||
"custom_message_types list when creating the team."
|
||||
)
|
||||
|
||||
if self._is_running:
|
||||
raise ValueError("The team is already running, it cannot run again until it is stopped.")
|
||||
self._is_running = True
|
||||
|
||||
if self._embedded_runtime:
|
||||
# Start the embedded runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
shutdown_task: asyncio.Task[None] | None = None
|
||||
if self._embedded_runtime:
|
||||
|
||||
async def stop_runtime() -> None:
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
try:
|
||||
# This will propagate any exceptions raised.
|
||||
await self._runtime.stop_when_idle()
|
||||
# Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
|
||||
# but not due to an exception.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(
|
||||
content="The group chat is stopped.", source=self._group_chat_manager_name
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Stop the consumption of messages and end the stream.
|
||||
# NOTE: we also need to put a GroupChatTermination event here because when the runtime
|
||||
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
|
||||
# This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(
|
||||
content="An exception occurred in the runtime.", source=self._group_chat_manager_name
|
||||
),
|
||||
error=SerializableException.from_exception(e),
|
||||
)
|
||||
)
|
||||
|
||||
# Create a background task to stop the runtime when the group chat
|
||||
# is stopped or has an exception.
|
||||
shutdown_task = asyncio.create_task(stop_runtime())
|
||||
|
||||
try:
|
||||
# Run the team by sending the start message to the group chat manager.
|
||||
# The group chat manager will start the group chat by relaying the message to the participants
|
||||
# and the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatStart(messages=messages, output_task_messages=output_task_messages),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Collect the output messages in order.
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
stop_reason: str | None = None
|
||||
|
||||
# Yield the messages until the queue is empty.
|
||||
while True:
|
||||
message_future = asyncio.ensure_future(self._output_message_queue.get())
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(message_future)
|
||||
# Wait for the next message, this will raise an exception if the task is cancelled.
|
||||
message = await message_future
|
||||
if isinstance(message, GroupChatTermination):
|
||||
# If the message contains an error, we need to raise it here.
|
||||
# This will stop the team and propagate the error.
|
||||
if message.error is not None:
|
||||
raise RuntimeError(str(message.error))
|
||||
stop_reason = message.message.content
|
||||
break
|
||||
yield message
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
output_messages.append(message)
|
||||
|
||||
# Yield the final result.
|
||||
yield TaskResult(messages=output_messages, stop_reason=stop_reason)
|
||||
|
||||
finally:
|
||||
try:
|
||||
if shutdown_task is not None:
|
||||
# Wait for the shutdown task to finish.
|
||||
# This will propagate any exceptions raised.
|
||||
await shutdown_task
|
||||
finally:
|
||||
# Clear the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and its participants to their initial state.
|
||||
|
||||
The team must be stopped before it can be reset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized or is currently running.
|
||||
|
||||
Example using the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` team:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
# Reset the team.
|
||||
await team.reset()
|
||||
stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
if self._is_running:
|
||||
raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")
|
||||
self._is_running = True
|
||||
|
||||
if self._embedded_runtime:
|
||||
# Start the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
try:
|
||||
# Send a reset messages to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a reset message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
finally:
|
||||
if self._embedded_runtime:
|
||||
# Stop the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
await self._runtime.stop_when_idle()
|
||||
|
||||
# Reset the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def pause(self) -> None:
|
||||
"""Pause its participants when the team is running by calling their
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method via direct RPC calls.
|
||||
|
||||
.. attention::
|
||||
|
||||
This is an experimental feature introduced in v0.4.9 and may subject
|
||||
to change or removal in the future.
|
||||
|
||||
The team must be initialized before it can be paused.
|
||||
|
||||
Different from termination, pausing the team does not cause the
|
||||
:meth:`run` or :meth:`run_stream` method to return. It calls the
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_pause` method on each
|
||||
participant, and if the participant does not implement the method, it
|
||||
will be a no-op.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the responsibility of the agent class to handle the pause
|
||||
and ensure that the agent can be resumed later.
|
||||
Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_pause`
|
||||
method in your agent class for custom pause behavior.
|
||||
By default, the agent will not do anything when called.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized. Exceptions from
|
||||
the participants when calling their implementations of
|
||||
:class:`~agentdhal_agentchat.base.ChatAgent.on_pause` are
|
||||
propagated to this method and raised.
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("The group chat has not been initialized. It must be run before it can be paused.")
|
||||
|
||||
# Send a pause message to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatPause(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a pause message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatPause(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
|
||||
async def resume(self) -> None:
|
||||
"""Resume its participants when the team is running and paused by calling their
|
||||
:meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method via direct RPC calls.
|
||||
|
||||
.. attention::
|
||||
|
||||
This is an experimental feature introduced in v0.4.9 and may subject
|
||||
to change or removal in the future.
|
||||
|
||||
The team must be initialized before it can be resumed.
|
||||
|
||||
Different from termination and restart with a new task, resuming the team
|
||||
does not cause the :meth:`run` or :meth:`run_stream` method to return.
|
||||
It calls the :meth:`~agentdhal_agentchat.base.ChatAgent.on_resume` method on each
|
||||
participant, and if the participant does not implement the method, it
|
||||
will be a no-op.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the responsibility of the agent class to handle the resume
|
||||
and ensure that the agent continues from where it was paused.
|
||||
Make sure to implement the :meth:`~agentdhal_agentchat.agents.BaseChatAgent.on_resume`
|
||||
method in your agent class for custom resume behavior.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the team has not been initialized. Exceptions from
|
||||
the participants when calling their implementations of :class:`~agentdhal_agentchat.base.ChatAgent.on_resume`
|
||||
method are propagated to this method and raised.
|
||||
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("The group chat has not been initialized. It must be run before it can be resumed.")
|
||||
|
||||
# Send a resume message to all participants.
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
await self._runtime.send_message(
|
||||
GroupChatResume(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self._team_id),
|
||||
)
|
||||
# Send a resume message to the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatResume(),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the group chat team.
|
||||
|
||||
The state is saved by calling the :meth:`~agentdhal_core.AgentRuntime.agent_save_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
The state is returned as a nested dictionary: a dictionary with key `agent_states`,
|
||||
which is a dictionary the agent names as keys and the state as values.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
{
|
||||
"agent_states": {
|
||||
"agent1": ...,
|
||||
"agent2": ...,
|
||||
"RoundRobinGroupChatManager": ...
|
||||
}
|
||||
}
|
||||
|
||||
.. note::
|
||||
|
||||
Starting v0.4.9, the state is using the agent name as the key instead of the agent ID,
|
||||
and the `team_id` field is removed from the state. This is to allow the state to be
|
||||
portable across different teams and runtimes. States saved with the old format
|
||||
may not be compatible with the new format in the future.
|
||||
|
||||
.. caution::
|
||||
|
||||
When calling :func:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` on a team
|
||||
while it is running, the state may not be consistent and may result in an unexpected state.
|
||||
It is recommended to call this method when the team is not running or after it is stopped.
|
||||
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
# Store state of each agent by their name.
|
||||
# NOTE: we don't use the agent ID as the key here because we need to be able to decouple
|
||||
# the state of the agents from their identities in the agent runtime.
|
||||
agent_states: Dict[str, Mapping[str, Any]] = {}
|
||||
# Save the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
# NOTE: We are using the runtime's save state method rather than the agent instance's
|
||||
# save_state method because we want to support saving state of remote agents.
|
||||
agent_states[name] = await self._runtime.agent_save_state(agent_id)
|
||||
# Save the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
agent_states[self._group_chat_manager_name] = await self._runtime.agent_save_state(agent_id)
|
||||
return TeamState(agent_states=agent_states).model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load an external state and overwrite the current state of the group chat team.
|
||||
|
||||
The state is loaded by calling the :meth:`~agentdhal_core.AgentRuntime.agent_load_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
See :meth:`~agentdhal_agentchat.teams.BaseGroupChat.save_state` for the expected format of the state.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
if self._is_running:
|
||||
raise RuntimeError("The team cannot be loaded while it is running.")
|
||||
self._is_running = True
|
||||
|
||||
try:
|
||||
team_state = TeamState.model_validate(state)
|
||||
# Load the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
if name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[name])
|
||||
# Load the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
if self._group_chat_manager_name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {self._group_chat_manager_name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[self._group_chat_manager_name])
|
||||
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"Invalid state format. The expected state format has changed since v0.4.9. "
|
||||
"Please read the release note on GitHub."
|
||||
) from e
|
||||
|
||||
finally:
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
@@ -0,0 +1,326 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from ...base import TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
"""Base class for a group chat manager that manages a group chat with multiple participants.
|
||||
|
||||
It is the responsibility of the caller to ensure:
|
||||
- All participants must subscribe to the group chat topic and each of their own topics.
|
||||
- The group chat manager must subscribe to the group chat topic.
|
||||
- The agent types of the participants must be unique.
|
||||
- For each participant, the agent type must be the same as the topic type.
|
||||
|
||||
Without the above conditions, the group chat will not function correctly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
description="Group chat manager",
|
||||
sequential_message_types=[
|
||||
GroupChatStart,
|
||||
GroupChatAgentResponse,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatMessage,
|
||||
GroupChatReset,
|
||||
],
|
||||
)
|
||||
if max_turns is not None and max_turns <= 0:
|
||||
raise ValueError("The maximum number of turns must be greater than 0.")
|
||||
if len(participant_topic_types) != len(participant_descriptions):
|
||||
raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.")
|
||||
if len(set(participant_topic_types)) != len(participant_topic_types):
|
||||
raise ValueError("The participant topic types must be unique.")
|
||||
if group_topic_type in participant_topic_types:
|
||||
raise ValueError("The group topic type must not be in the participant topic types.")
|
||||
self._name = name
|
||||
self._group_topic_type = group_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._participant_names = participant_names
|
||||
self._participant_name_to_topic_type = {
|
||||
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
|
||||
}
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
self._output_message_queue = output_message_queue
|
||||
self._termination_condition = termination_condition
|
||||
self._max_turns = max_turns
|
||||
self._current_turn = 0
|
||||
self._message_factory = message_factory
|
||||
self._emit_team_events = emit_team_events
|
||||
self._active_speakers: List[str] = []
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
|
||||
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(
|
||||
content="The group chat has already terminated.",
|
||||
source=self._name,
|
||||
)
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Validate the group state given the start messages
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
if message.messages is not None:
|
||||
# Log all messages at once
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
# Only put messages in output queue if output_task_messages is True
|
||||
if message.output_task_messages:
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Relay all messages at once to participants
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
# Append all messages to thread
|
||||
await self.update_message_thread(message.messages)
|
||||
|
||||
# Check termination condition after processing all messages
|
||||
if await self._apply_termination_condition(message.messages):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select speakers to start/continue the conversation
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(
|
||||
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
|
||||
) -> None:
|
||||
try:
|
||||
# Construct the detla from the agent response.
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if isinstance(message, GroupChatAgentResponse):
|
||||
if message.response.inner_messages is not None:
|
||||
for inner_message in message.response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
delta.append(message.response.chat_message)
|
||||
else:
|
||||
delta.extend(message.result.messages)
|
||||
|
||||
# Append the messages to the message thread.
|
||||
await self.update_message_thread(delta)
|
||||
|
||||
# Remove the agent from the active speakers list.
|
||||
self._active_speakers.remove(message.name)
|
||||
if len(self._active_speakers) > 0:
|
||||
# If there are still active speakers, return without doing anything.
|
||||
return
|
||||
|
||||
# Check if the conversation should be terminated.
|
||||
if await self._apply_termination_condition(delta, increment_turn_count=True):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select speakers to continue the conversation.
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
# Handle the exception and signal termination with an error.
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the exception to the runtime.
|
||||
raise
|
||||
|
||||
async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
|
||||
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
cancellation_token.link_future(speaker_names_future)
|
||||
speaker_names = await speaker_names_future
|
||||
if isinstance(speaker_names, str):
|
||||
# If only one speaker is selected, convert it to a list.
|
||||
speaker_names = [speaker_names]
|
||||
for speaker_name in speaker_names:
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
await self._log_speaker_selection(speaker_names)
|
||||
|
||||
# Send request to publish message to the next speakers
|
||||
for speaker_name in speaker_names:
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
self._active_speakers.append(speaker_name)
|
||||
|
||||
async def _apply_termination_condition(
|
||||
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
|
||||
) -> bool:
|
||||
"""Apply the termination condition to the delta and return True if the conversation should be terminated.
|
||||
It also resets the termination condition and turn count, and signals termination to the caller of the team."""
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions and turn count.
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return True
|
||||
if increment_turn_count:
|
||||
# Increment the turn count.
|
||||
self._current_turn += 1
|
||||
# Check if the maximum number of turns has been reached.
|
||||
if self._max_turns is not None:
|
||||
if self._current_turn >= self._max_turns:
|
||||
stop_message = StopMessage(
|
||||
content=f"Maximum number of turns {self._max_turns} reached.",
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _log_speaker_selection(self, speaker_names: List[str]) -> None:
|
||||
"""Log the selected speaker to the output message queue."""
|
||||
select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name)
|
||||
if self._emit_team_events:
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
async def _signal_termination(self, message: StopMessage) -> None:
|
||||
termination_event = GroupChatTermination(message=message)
|
||||
# Log the early stop message.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
async def _signal_termination_with_error(self, error: SerializableException) -> None:
|
||||
termination_event = GroupChatTermination(
|
||||
message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error
|
||||
)
|
||||
# Log the termination event.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat message by appending the content to its output message queue."""
|
||||
await self._output_message_queue.put(message.message)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None:
|
||||
"""Handle a group chat error by logging the error and signaling termination."""
|
||||
await self._signal_termination_with_error(message.error)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
"""Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager
|
||||
and clear the message thread."""
|
||||
await self.reset()
|
||||
|
||||
@rpc
|
||||
async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
|
||||
"""Pause the group chat manager. This is a no-op in the base class."""
|
||||
pass
|
||||
|
||||
@rpc
|
||||
async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None:
|
||||
"""Resume the group chat manager. This is a no-op in the base class."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the state of the group chat given the start messages.
|
||||
This is executed when the group chat manager receives a GroupChatStart event.
|
||||
|
||||
Args:
|
||||
messages: A list of chat messages to validate, or None if no messages are provided.
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
"""Update the message thread with the new messages.
|
||||
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
|
||||
before calling the select_speakers method.
|
||||
"""
|
||||
self._message_thread.extend(messages)
|
||||
|
||||
@abstractmethod
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select speakers from the participants and return the topic types of the selected speaker.
|
||||
This is called when the group chat manager have received all responses from the participants
|
||||
for a turn and is ready to select the next speakers for the next turn.
|
||||
|
||||
Args:
|
||||
thread: The message thread of the group chat.
|
||||
|
||||
Returns:
|
||||
A list of topic types of the selected speakers.
|
||||
If only one speaker is selected, a single string is returned instead of a list.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the group chat manager."""
|
||||
...
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in group chat manager: {type(message)}")
|
||||
@@ -0,0 +1,213 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from agentdhal_core import DefaultTopicId, MessageContext, event, rpc, trace_invoke_agent_span
|
||||
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
|
||||
from ...base import ChatAgent, Response, TaskResult, Team
|
||||
from ...state import ChatAgentContainerState
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatError,
|
||||
GroupChatMessage,
|
||||
GroupChatPause,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatResume,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
SerializableException,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class ChatAgentContainer(SequentialRoutedAgent):
|
||||
"""A core agent class that delegates message handling to an
|
||||
:class:`agentdhal_agentchat.base.ChatAgent` or :class:`agentdhal_agentchat.base.Team`
|
||||
so that it can be used in a group chat team.
|
||||
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
output_topic_type (str): The topic type for the output.
|
||||
agent (ChatAgent | Team): The agent or team to delegate message handling to.
|
||||
message_factory (MessageFactory): The message factory to use for
|
||||
creating messages from JSON data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent | Team, message_factory: MessageFactory
|
||||
) -> None:
|
||||
super().__init__(
|
||||
description=agent.description,
|
||||
sequential_message_types=[
|
||||
GroupChatStart,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatAgentResponse,
|
||||
GroupChatTeamResponse,
|
||||
],
|
||||
)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[BaseChatMessage] = []
|
||||
self._message_factory = message_factory
|
||||
|
||||
@event
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle a start event by appending the content to the buffer."""
|
||||
if message.messages is not None:
|
||||
for msg in message.messages:
|
||||
self._buffer_message(msg)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
"""Handle an agent response event by appending the content to the buffer."""
|
||||
self._buffer_message(message.response.chat_message)
|
||||
|
||||
@event
|
||||
async def handle_team_response(self, message: GroupChatTeamResponse, ctx: MessageContext) -> None:
|
||||
"""Handle a team response event by appending the content to the buffer."""
|
||||
for msg in message.result.messages:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
self._buffer_message(msg)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
"""Handle a reset event by resetting the agent."""
|
||||
self._message_buffer.clear()
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, reset the team.
|
||||
await self._agent.reset()
|
||||
else:
|
||||
await self._agent.on_reset(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
|
||||
"""Handle a content request event by passing the messages in the buffer
|
||||
to the delegate agent and publish the response."""
|
||||
if isinstance(self._agent, Team):
|
||||
try:
|
||||
stream = self._agent.run_stream(
|
||||
task=self._message_buffer,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
output_task_messages=False,
|
||||
)
|
||||
result: TaskResult | None = None
|
||||
async for team_event in stream:
|
||||
if isinstance(team_event, TaskResult):
|
||||
result = team_event
|
||||
else:
|
||||
await self._log_message(team_event)
|
||||
if result is None:
|
||||
raise RuntimeError(
|
||||
"The team did not produce a final TaskResult. Check the team's run_stream method."
|
||||
)
|
||||
self._message_buffer.clear()
|
||||
# Publish the team response to the group chat.
|
||||
await self.publish_message(
|
||||
GroupChatTeamResponse(result=result, name=self._agent.name),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Publish the error to the group chat.
|
||||
error_message = SerializableException.from_exception(e)
|
||||
await self.publish_message(
|
||||
GroupChatError(error=error_message),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
else:
|
||||
# If the agent is not a team, handle it as a single agent.
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self._agent.name,
|
||||
agent_description=self._agent.description,
|
||||
agent_id=str(self.id),
|
||||
):
|
||||
try:
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response: Response | None = None
|
||||
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
await self._log_message(msg.chat_message)
|
||||
response = msg
|
||||
else:
|
||||
await self._log_message(msg)
|
||||
if response is None:
|
||||
raise RuntimeError(
|
||||
"The agent did not produce a final response. Check the agent's on_messages_stream method."
|
||||
)
|
||||
# Publish the response to the group chat.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=response, name=self._agent.name),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Publish the error to the group chat.
|
||||
error_message = SerializableException.from_exception(e)
|
||||
await self.publish_message(
|
||||
GroupChatError(error=error_message),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
def _buffer_message(self, message: BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Buffer the message.
|
||||
self._message_buffer.append(message)
|
||||
|
||||
async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None:
|
||||
if not self._message_factory.is_registered(message.__class__):
|
||||
raise ValueError(f"Message type {message.__class__} is not registered.")
|
||||
# Log the message.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
@rpc
|
||||
async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
|
||||
"""Handle a pause event by pausing the agent."""
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, pause the team.
|
||||
await self._agent.pause()
|
||||
else:
|
||||
await self._agent.on_pause(ctx.cancellation_token)
|
||||
|
||||
@rpc
|
||||
async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None:
|
||||
"""Handle a resume event by resuming the agent."""
|
||||
if isinstance(self._agent, Team):
|
||||
# If the agent is a team, resume the team.
|
||||
await self._agent.resume()
|
||||
else:
|
||||
await self._agent.on_resume(ctx.cancellation_token)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
agent_state = await self._agent.save_state()
|
||||
state = ChatAgentContainerState(
|
||||
agent_state=agent_state, message_buffer=[message.dump() for message in self._message_buffer]
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
container_state = ChatAgentContainerState.model_validate(state)
|
||||
self._message_buffer = []
|
||||
for message_data in container_state.message_buffer:
|
||||
message = self._message_factory.create(message_data)
|
||||
if isinstance(message, BaseChatMessage):
|
||||
self._message_buffer.append(message)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in message buffer: {type(message)}")
|
||||
await self._agent.load_state(container_state.agent_state)
|
||||
113
agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py
Normal file
113
agent_dhal/agentdhal_agentchat/teams/_group_chat/_events.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ...base import Response, TaskResult
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class SerializableException(BaseModel):
|
||||
"""A serializable exception."""
|
||||
|
||||
error_type: str
|
||||
"""The type of error that occurred."""
|
||||
|
||||
error_message: str
|
||||
"""The error message that describes the error."""
|
||||
|
||||
traceback: str | None = None
|
||||
"""The traceback of the error, if available."""
|
||||
|
||||
@classmethod
|
||||
def from_exception(cls, exc: Exception) -> "SerializableException":
|
||||
"""Create a GroupChatError from an exception."""
|
||||
return cls(
|
||||
error_type=type(exc).__name__,
|
||||
error_message=str(exc),
|
||||
traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the error, including the traceback if available."""
|
||||
if self.traceback:
|
||||
return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}"
|
||||
return f"{self.error_type}: {self.error_message}"
|
||||
|
||||
|
||||
class GroupChatStart(BaseModel):
|
||||
"""A request to start a group chat."""
|
||||
|
||||
messages: List[SerializeAsAny[BaseChatMessage]] | None = None
|
||||
"""An optional list of messages to start the group chat."""
|
||||
|
||||
output_task_messages: bool = True
|
||||
"""Whether to include task messages in the output. Defaults to True for backward compatibility."""
|
||||
|
||||
|
||||
class GroupChatAgentResponse(BaseModel):
|
||||
"""A response published to a group chat."""
|
||||
|
||||
response: SerializeAsAny[Response]
|
||||
"""The response from an agent."""
|
||||
|
||||
name: str
|
||||
"""The name of the agent that produced the response."""
|
||||
|
||||
|
||||
class GroupChatTeamResponse(BaseModel):
|
||||
"""A response published to a group chat from a team."""
|
||||
|
||||
result: SerializeAsAny[TaskResult]
|
||||
"""The result from a team."""
|
||||
|
||||
name: str
|
||||
"""The name of the team that produced the response."""
|
||||
|
||||
|
||||
class GroupChatRequestPublish(BaseModel):
|
||||
"""A request to publish a message to a group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatMessage(BaseModel):
|
||||
"""A message from a group chat."""
|
||||
|
||||
message: SerializeAsAny[BaseAgentEvent | BaseChatMessage]
|
||||
"""The message that was published."""
|
||||
|
||||
|
||||
class GroupChatTermination(BaseModel):
|
||||
"""A message indicating that a group chat has terminated."""
|
||||
|
||||
message: StopMessage
|
||||
"""The stop message that indicates the reason of termination."""
|
||||
|
||||
error: SerializableException | None = None
|
||||
"""The error that occurred, if any."""
|
||||
|
||||
|
||||
class GroupChatReset(BaseModel):
|
||||
"""A request to reset the agents in the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatPause(BaseModel):
|
||||
"""A request to pause the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatResume(BaseModel):
|
||||
"""A request to resume the group chat."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class GroupChatError(BaseModel):
|
||||
"""A message indicating that an error occurred in the group chat."""
|
||||
|
||||
error: SerializableException
|
||||
"""The error that occurred."""
|
||||
@@ -0,0 +1,17 @@
|
||||
from ._digraph_group_chat import (
|
||||
DiGraph,
|
||||
DiGraphEdge,
|
||||
DiGraphNode,
|
||||
GraphFlow,
|
||||
GraphFlowManager,
|
||||
)
|
||||
from ._graph_builder import DiGraphBuilder
|
||||
|
||||
__all__ = [
|
||||
"GraphFlow",
|
||||
"DiGraph",
|
||||
"GraphFlowManager",
|
||||
"DiGraphNode",
|
||||
"DiGraphEdge",
|
||||
"DiGraphBuilder",
|
||||
]
|
||||
@@ -0,0 +1,877 @@
|
||||
import asyncio
|
||||
from collections import Counter, deque
|
||||
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.base import ChatAgent, TerminationCondition
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
MessageFactory,
|
||||
StopMessage,
|
||||
)
|
||||
from agentdhal_agentchat.state import BaseGroupChatManagerState
|
||||
from agentdhal_agentchat.teams import BaseGroupChat
|
||||
|
||||
from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
|
||||
from ..._group_chat._events import GroupChatTermination
|
||||
|
||||
_DIGRAPH_STOP_MESSAGE = "Digraph execution is complete"
|
||||
|
||||
|
||||
class DiGraphEdge(BaseModel):
|
||||
"""Represents a directed edge in a :class:`DiGraph`, with an optional execution condition.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
.. warning::
|
||||
|
||||
If the condition is a callable, it will not be serialized in the model.
|
||||
|
||||
"""
|
||||
|
||||
target: str # Target node name
|
||||
condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None)
|
||||
"""(Experimental) Condition to execute this edge.
|
||||
If None, the edge is unconditional.
|
||||
If a string, the edge is conditional on the presence of that string in the last agent chat message.
|
||||
If a callable, the edge is conditional on the callable returning True when given the last message.
|
||||
"""
|
||||
|
||||
# Using Field to exclude the condition in serialization if it's a callable
|
||||
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
|
||||
activation_group: str = Field(default="")
|
||||
"""Group identifier for forward dependencies.
|
||||
|
||||
When multiple edges point to the same target node, they are grouped by this field.
|
||||
This allows distinguishing between different cycles or dependency patterns.
|
||||
|
||||
Example: In a graph containing a cycle like A->B->C->B, the two edges pointing to B (A->B and C->B)
|
||||
can be in different activation groups to control how B is activated.
|
||||
Defaults to the target node name if not specified.
|
||||
"""
|
||||
activation_condition: Literal["all", "any"] = "all"
|
||||
"""Determines how forward dependencies within the same activation_group are evaluated.
|
||||
|
||||
- "all": All edges in this activation group must be satisfied before the target node can execute
|
||||
- "any": Any single edge in this activation group being satisfied allows the target node to execute
|
||||
|
||||
This is used to handle complex dependency patterns in cyclic graphs where multiple
|
||||
paths can lead to the same target node.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_condition(self) -> "DiGraphEdge":
|
||||
# Store callable in a separate field and set condition to None for serialization
|
||||
if callable(self.condition):
|
||||
self.condition_function = self.condition
|
||||
# For serialization purposes, we'll set the condition to None
|
||||
# when storing as a pydantic model/dict
|
||||
object.__setattr__(self, "condition", None)
|
||||
|
||||
# Set activation_group to target if not already set
|
||||
if not self.activation_group:
|
||||
self.activation_group = self.target
|
||||
|
||||
return self
|
||||
|
||||
def check_condition(self, message: BaseChatMessage) -> bool:
|
||||
"""Check if the edge condition is satisfied for the given message.
|
||||
|
||||
Args:
|
||||
message: The message to check the condition against.
|
||||
|
||||
Returns:
|
||||
True if condition is satisfied (None condition always returns True),
|
||||
False otherwise.
|
||||
"""
|
||||
if self.condition_function is not None:
|
||||
return self.condition_function(message)
|
||||
elif isinstance(self.condition, str):
|
||||
# If it's a string, check if the string is in the message content
|
||||
return self.condition in message.to_model_text()
|
||||
return True # None condition is always satisfied
|
||||
|
||||
|
||||
class DiGraphNode(BaseModel):
|
||||
"""Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
"""
|
||||
|
||||
name: str # Agent's name
|
||||
edges: List[DiGraphEdge] = [] # Outgoing edges
|
||||
activation: Literal["all", "any"] = "all"
|
||||
|
||||
|
||||
class DiGraph(BaseModel):
|
||||
"""Defines a directed graph structure with nodes and edges.
|
||||
:class:`GraphFlow` uses this to determine execution order and conditions.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
"""
|
||||
|
||||
nodes: Dict[str, DiGraphNode] # Node name → DiGraphNode mapping
|
||||
default_start_node: str | None = None # Default start node name
|
||||
_has_cycles: bool | None = None # Cyclic graph flag
|
||||
|
||||
def get_parents(self) -> Dict[str, List[str]]:
|
||||
"""Compute a mapping of each node to its parent nodes."""
|
||||
parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
parents[edge.target].append(node.name)
|
||||
return parents
|
||||
|
||||
def get_start_nodes(self) -> Set[str]:
|
||||
"""Return the nodes that have no incoming edges (entry points)."""
|
||||
if self.default_start_node:
|
||||
return {self.default_start_node}
|
||||
|
||||
parents = self.get_parents()
|
||||
return set([node_name for node_name, parent_list in parents.items() if not parent_list])
|
||||
|
||||
def get_leaf_nodes(self) -> Set[str]:
|
||||
"""Return nodes that have no outgoing edges (final output nodes)."""
|
||||
return set([name for name, node in self.nodes.items() if not node.edges])
|
||||
|
||||
def has_cycles_with_exit(self) -> bool:
|
||||
"""
|
||||
Check if the graph has any cycles and validate that each cycle has at least one conditional edge.
|
||||
|
||||
Returns:
|
||||
bool: True if there is at least one cycle and all cycles have an exit condition.
|
||||
False if there are no cycles.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is a cycle without any conditional edge.
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
rec_stack: Set[str] = set()
|
||||
path: List[str] = []
|
||||
|
||||
def dfs(node_name: str) -> bool:
|
||||
visited.add(node_name)
|
||||
rec_stack.add(node_name)
|
||||
path.append(node_name)
|
||||
|
||||
for edge in self.nodes[node_name].edges:
|
||||
target = edge.target
|
||||
if target not in visited:
|
||||
if dfs(target):
|
||||
return True
|
||||
elif target in rec_stack:
|
||||
# Found a cycle → extract the cycle
|
||||
cycle_start_index = path.index(target)
|
||||
cycle_nodes = path[cycle_start_index:]
|
||||
cycle_edges: List[DiGraphEdge] = []
|
||||
for n in cycle_nodes:
|
||||
cycle_edges.extend(self.nodes[n].edges)
|
||||
if all(edge.condition is None and edge.condition_function is None for edge in cycle_edges):
|
||||
raise ValueError(
|
||||
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
|
||||
)
|
||||
return True # Found cycle, but it has an exit condition
|
||||
|
||||
rec_stack.remove(node_name)
|
||||
path.pop()
|
||||
return False
|
||||
|
||||
has_cycle = False
|
||||
for node in self.nodes:
|
||||
if node not in visited:
|
||||
if dfs(node):
|
||||
has_cycle = True
|
||||
|
||||
return has_cycle
|
||||
|
||||
def get_has_cycles(self) -> bool:
|
||||
"""Indicates if the graph has at least one cycle (with valid exit conditions)."""
|
||||
if self._has_cycles is None:
|
||||
self._has_cycles = self.has_cycles_with_exit()
|
||||
|
||||
return self._has_cycles
|
||||
|
||||
def graph_validate(self) -> None:
|
||||
"""Validate graph structure and execution rules."""
|
||||
if not self.nodes:
|
||||
raise ValueError("Graph has no nodes.")
|
||||
|
||||
if not self.get_start_nodes():
|
||||
raise ValueError("Graph must have at least one start node")
|
||||
|
||||
if not self.get_leaf_nodes():
|
||||
raise ValueError("Graph must have at least one leaf node")
|
||||
|
||||
# Outgoing edge condition validation (per node)
|
||||
for node in self.nodes.values():
|
||||
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
||||
has_condition = any(
|
||||
edge.condition is not None or edge.condition_function is not None for edge in node.edges
|
||||
)
|
||||
has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges)
|
||||
if has_condition and has_unconditioned:
|
||||
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
||||
|
||||
# Validate activation conditions across all edges in the graph
|
||||
self._validate_activation_conditions()
|
||||
|
||||
self._has_cycles = self.has_cycles_with_exit()
|
||||
|
||||
def _validate_activation_conditions(self) -> None:
|
||||
"""Validate that all edges pointing to the same target node have consistent activation_condition values.
|
||||
|
||||
Raises:
|
||||
ValueError: If edges pointing to the same target have different activation_condition values
|
||||
"""
|
||||
target_activation_conditions: Dict[str, Dict[str, str]] = {} # target_node -> {activation_group -> condition}
|
||||
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target # The target node this edge points to
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if target not in target_activation_conditions:
|
||||
target_activation_conditions[target] = {}
|
||||
|
||||
if activation_group in target_activation_conditions[target]:
|
||||
if target_activation_conditions[target][activation_group] != edge.activation_condition:
|
||||
# Find the source node that has the conflicting condition
|
||||
conflicting_source = self._find_edge_source_by_target_and_group(
|
||||
target, activation_group, target_activation_conditions[target][activation_group]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Conflicting activation conditions for target '{target}' group '{activation_group}': "
|
||||
f"'{target_activation_conditions[target][activation_group]}' (from node '{conflicting_source}') "
|
||||
f"and '{edge.activation_condition}' (from node '{node.name}')"
|
||||
)
|
||||
else:
|
||||
target_activation_conditions[target][activation_group] = edge.activation_condition
|
||||
|
||||
def _find_edge_source_by_target_and_group(
|
||||
self, target: str, activation_group: str, activation_condition: str
|
||||
) -> str:
|
||||
"""Find the source node that has an edge pointing to the given target with the given activation_group and activation_condition."""
|
||||
for node_name, node in self.nodes.items():
|
||||
for edge in node.edges:
|
||||
if (
|
||||
edge.target == target
|
||||
and edge.activation_group == activation_group
|
||||
and edge.activation_condition == activation_condition
|
||||
):
|
||||
return node_name
|
||||
return "unknown"
|
||||
|
||||
def get_remaining_map(self) -> Dict[str, Dict[str, int]]:
|
||||
"""Get the remaining map that tracks how many edges point to each target node with each activation group.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping target nodes to their activation groups and remaining counts
|
||||
"""
|
||||
|
||||
remaining_map: Dict[str, Dict[str, int]] = {}
|
||||
|
||||
for node in self.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if target not in remaining_map:
|
||||
remaining_map[target] = {}
|
||||
|
||||
if activation_group not in remaining_map[target]:
|
||||
remaining_map[target][activation_group] = 0
|
||||
|
||||
remaining_map[target][activation_group] += 1
|
||||
|
||||
return remaining_map
|
||||
|
||||
|
||||
class GraphFlowManagerState(BaseGroupChatManagerState):
|
||||
"""Tracks active execution state for DAG-based execution."""
|
||||
|
||||
active_nodes: List[str] = [] # Currently executing nodes
|
||||
type: str = "GraphManagerState"
|
||||
|
||||
|
||||
class GraphFlowManager(BaseGroupChatManager):
|
||||
"""Manages execution of agents using a Directed Graph execution model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
graph: DiGraph,
|
||||
) -> None:
|
||||
"""Initialize the graph-based execution manager."""
|
||||
super().__init__(
|
||||
name=name,
|
||||
group_topic_type=group_topic_type,
|
||||
output_topic_type=output_topic_type,
|
||||
participant_topic_types=participant_topic_types,
|
||||
participant_names=participant_names,
|
||||
participant_descriptions=participant_descriptions,
|
||||
output_message_queue=output_message_queue,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
message_factory=message_factory,
|
||||
)
|
||||
graph.graph_validate()
|
||||
if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
|
||||
raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
|
||||
self._graph = graph
|
||||
# Lookup table for incoming edges for each node.
|
||||
self._parents = graph.get_parents()
|
||||
# Lookup table for outgoing edges for each node.
|
||||
self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
|
||||
|
||||
# Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node
|
||||
self._build_lookup_tables(graph)
|
||||
|
||||
# Track which activation groups were triggered for each node
|
||||
self._triggered_activation_groups: Dict[str, Set[str]] = {}
|
||||
# === Mutable states for the graph execution ===
|
||||
# Count the number of remaining parents to activate each node.
|
||||
self._remaining: Dict[str, Counter[str]] = {
|
||||
target: Counter(groups) for target, groups in graph.get_remaining_map().items()
|
||||
}
|
||||
# cache for remaining
|
||||
self._origin_remaining: Dict[str, Dict[str, int]] = {
|
||||
target: Counter(groups) for target, groups in self._remaining.items()
|
||||
}
|
||||
|
||||
# Ready queue for nodes that are ready to execute, starting with the start nodes.
|
||||
self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()])
|
||||
|
||||
def _build_lookup_tables(self, graph: DiGraph) -> None:
|
||||
"""Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node.
|
||||
|
||||
Args:
|
||||
graph: The directed graph
|
||||
"""
|
||||
self._activation: Dict[str, Dict[str, Literal["any", "all"]]] = {}
|
||||
self._enqueued_any: Dict[str, Dict[str, bool]] = {}
|
||||
|
||||
for node in graph.nodes.values():
|
||||
for edge in node.edges:
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
# Build activation lookup
|
||||
if target not in self._activation:
|
||||
self._activation[target] = {}
|
||||
if activation_group not in self._activation[target]:
|
||||
self._activation[target][activation_group] = edge.activation_condition
|
||||
|
||||
# Build enqueued_any lookup
|
||||
if target not in self._enqueued_any:
|
||||
self._enqueued_any[target] = {}
|
||||
if activation_group not in self._enqueued_any[target]:
|
||||
self._enqueued_any[target][activation_group] = False
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
await super().update_message_thread(messages)
|
||||
|
||||
# Find the node that ran in the current turn.
|
||||
message = messages[-1]
|
||||
if message.source not in self._graph.nodes:
|
||||
# Ignore messages from sources outside of the graph.
|
||||
return
|
||||
assert isinstance(message, BaseChatMessage)
|
||||
source = message.source
|
||||
|
||||
# Propagate the update to the children of the node.
|
||||
for edge in self._edges[source]:
|
||||
# Use the new check_condition method that handles both string and callable conditions
|
||||
if not edge.check_condition(message):
|
||||
continue
|
||||
|
||||
target = edge.target
|
||||
activation_group = edge.activation_group
|
||||
|
||||
if self._activation[target][activation_group] == "all":
|
||||
self._remaining[target][activation_group] -= 1
|
||||
if self._remaining[target][activation_group] == 0:
|
||||
# If all parents are done, add to the ready queue.
|
||||
self._ready.append(target)
|
||||
# Track which activation group was triggered
|
||||
self._save_triggered_activation_group(target, activation_group)
|
||||
else:
|
||||
# If activation is any, add to the ready queue if not already enqueued.
|
||||
if not self._enqueued_any[target][activation_group]:
|
||||
self._ready.append(target)
|
||||
self._enqueued_any[target][activation_group] = True
|
||||
# Track which activation group was triggered
|
||||
self._save_triggered_activation_group(target, activation_group)
|
||||
|
||||
def _save_triggered_activation_group(self, target: str, activation_group: str) -> None:
|
||||
"""Save which activation group was triggered for a target node.
|
||||
|
||||
Args:
|
||||
target: The target node that was triggered
|
||||
activation_group: The activation group that caused the trigger
|
||||
"""
|
||||
if target not in self._triggered_activation_groups:
|
||||
self._triggered_activation_groups[target] = set()
|
||||
self._triggered_activation_groups[target].add(activation_group)
|
||||
|
||||
def _reset_triggered_activation_groups(self, speaker: str) -> None:
|
||||
"""Reset the bookkeeping for the specific activation groups that were triggered for a speaker.
|
||||
|
||||
Args:
|
||||
speaker: The speaker node to reset activation groups for
|
||||
"""
|
||||
if speaker not in self._triggered_activation_groups:
|
||||
return
|
||||
|
||||
for activation_group in self._triggered_activation_groups[speaker]:
|
||||
if self._activation[speaker][activation_group] == "any":
|
||||
self._enqueued_any[speaker][activation_group] = False
|
||||
else:
|
||||
# Reset the remaining count for this activation group using the graph's original count
|
||||
if speaker in self._remaining and activation_group in self._remaining[speaker]:
|
||||
self._remaining[speaker][activation_group] = self._origin_remaining[speaker][activation_group]
|
||||
|
||||
# Clear the triggered activation groups for this speaker
|
||||
self._triggered_activation_groups[speaker].clear()
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
# Drain the ready queue for the next set of speakers.
|
||||
speakers: List[str] = []
|
||||
while self._ready:
|
||||
speaker = self._ready.popleft()
|
||||
speakers.append(speaker)
|
||||
|
||||
# Reset the bookkeeping for the specific activation groups that were triggered
|
||||
self._reset_triggered_activation_groups(speaker)
|
||||
|
||||
return speakers
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def _apply_termination_condition(
|
||||
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
|
||||
) -> bool:
|
||||
"""Apply termination condition including graph-specific completion logic.
|
||||
|
||||
First checks if graph execution is complete, then checks standard termination conditions.
|
||||
|
||||
Args:
|
||||
delta: The message delta to check termination conditions against
|
||||
increment_turn_count: Whether to increment the turn count
|
||||
|
||||
Returns:
|
||||
True if the conversation should be terminated, False otherwise
|
||||
"""
|
||||
# Check if the graph execution is complete (no ready speakers) - prioritize this check
|
||||
if not self._ready:
|
||||
stop_message = StopMessage(
|
||||
content=_DIGRAPH_STOP_MESSAGE,
|
||||
source=self._name,
|
||||
)
|
||||
# Reset the execution state when the graph has naturally completed
|
||||
self._reset_execution_state()
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
return True
|
||||
|
||||
# Apply the standard termination conditions from the base class
|
||||
return await super()._apply_termination_condition(delta, increment_turn_count)
|
||||
|
||||
def _reset_execution_state(self) -> None:
|
||||
"""Reset the graph execution state to the initial state."""
|
||||
self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
|
||||
self._enqueued_any = {n: {g: False for g in self._enqueued_any[n]} for n in self._enqueued_any}
|
||||
self._ready = deque([n for n in self._graph.get_start_nodes()])
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the execution state."""
|
||||
state = {
|
||||
"message_thread": [message.dump() for message in self._message_thread],
|
||||
"current_turn": self._current_turn,
|
||||
"remaining": {target: dict(counter) for target, counter in self._remaining.items()},
|
||||
"enqueued_any": dict(self._enqueued_any),
|
||||
"ready": list(self._ready),
|
||||
}
|
||||
return state
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore execution state from saved data."""
|
||||
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
|
||||
self._current_turn = state["current_turn"]
|
||||
self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()}
|
||||
self._enqueued_any = state["enqueued_any"]
|
||||
self._ready = deque(state["ready"])
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset execution state to the start of the graph."""
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition:
|
||||
await self._termination_condition.reset()
|
||||
self._reset_execution_state()
|
||||
|
||||
|
||||
class GraphFlowConfig(BaseModel):
|
||||
"""The declarative configuration for GraphFlow."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
graph: DiGraph # The execution graph for agents
|
||||
|
||||
|
||||
class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
"""A team that runs a group chat following a Directed Graph execution pattern.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This group chat executes agents based on a directed graph (:class:`DiGraph`) structure,
|
||||
allowing complex workflows such as sequential execution, parallel fan-out,
|
||||
conditional branching, join patterns, and loops with explicit exit conditions.
|
||||
|
||||
The execution order is determined by the edges defined in the `DiGraph`. Each node
|
||||
in the graph corresponds to an agent, and edges define the flow of messages between agents.
|
||||
Nodes can be configured to activate when:
|
||||
|
||||
- **All** parent nodes have completed (activation="all") → default
|
||||
- **Any** parent node completes (activation="any")
|
||||
|
||||
Conditional branching is supported using edge conditions, where the next agent(s) are selected
|
||||
based on content in the chat history. Loops are permitted as long as there is a condition
|
||||
that eventually exits the loop.
|
||||
|
||||
.. note::
|
||||
|
||||
Use the :class:`DiGraphBuilder` class to create a :class:`DiGraph` easily. It provides a fluent API
|
||||
for adding nodes and edges, setting entry points, and validating the graph structure.
|
||||
See the :class:`DiGraphBuilder` documentation for more details.
|
||||
The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows.
|
||||
|
||||
.. warning::
|
||||
|
||||
When using callable conditions in edges, they will not be serialized
|
||||
when calling :meth:`dump_component`. This will be addressed in future releases.
|
||||
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The participants in the group chat.
|
||||
termination_condition (TerminationCondition, optional): Termination condition for the chat.
|
||||
max_turns (int, optional): Maximum number of turns before forcing termination.
|
||||
graph (DiGraph): Directed execution graph defining node flow and conditions.
|
||||
|
||||
Raises:
|
||||
ValueError: If participant names are not unique, or if graph validation fails (e.g., cycles without exit).
|
||||
|
||||
Examples:
|
||||
|
||||
**Sequential Flow: A → B → C**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to English.")
|
||||
|
||||
# Create a directed graph with sequential flow A -> B -> C.
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short story about a cat."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Parallel Fan-out: A → (B, C)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")
|
||||
|
||||
# Create a directed graph with fan-out flow A -> (B, C).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short story about a cat."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Conditional Branching: A → B (if 'yes') or C (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
agent_a = AssistantAgent(
|
||||
"A",
|
||||
model_client=model_client,
|
||||
system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
|
||||
)
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
|
||||
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
# Create conditions as callables that check the message content.
|
||||
builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
|
||||
builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Loop with exit condition: A → B → C (if 'APPROVE') or A (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import DiGraphBuilder, GraphFlow
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize agents with OpenAI model clients.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
agent_a = AssistantAgent(
|
||||
"A",
|
||||
model_client=model_client,
|
||||
system_message="You are a helpful assistant.",
|
||||
)
|
||||
agent_b = AssistantAgent(
|
||||
"B",
|
||||
model_client=model_client,
|
||||
system_message="Provide feedback on the input, if your feedback has been addressed, "
|
||||
"say 'APPROVE', otherwise provide a reason for rejection.",
|
||||
)
|
||||
agent_c = AssistantAgent(
|
||||
"C", model_client=model_client, system_message="Translate the final product to Korean."
|
||||
)
|
||||
|
||||
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b)
|
||||
|
||||
# Create conditional edges using strings
|
||||
builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text())
|
||||
builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text())
|
||||
|
||||
builder.set_entry_point(agent_a)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
team = GraphFlow(
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
termination_condition=MaxMessageTermination(20), # Max 20 messages to avoid infinite loop.
|
||||
)
|
||||
|
||||
# Run the team and print the events.
|
||||
async for event in team.run_stream(task="Write a short poem about AI Agents."):
|
||||
print(event)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = GraphFlowConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.GraphFlow"
|
||||
|
||||
DEFAULT_NAME = "GraphFlow"
|
||||
DEFAULT_DESCRIPTION = "A team of agents"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
graph: DiGraph,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
) -> None:
|
||||
self._input_participants = participants
|
||||
self._input_termination_condition = termination_condition
|
||||
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
|
||||
# No longer add _StopAgent or StopMessageTermination
|
||||
# Termination is now handled directly in GraphFlowManager._apply_termination_condition
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=list(participants),
|
||||
group_chat_manager_name="GraphManager",
|
||||
group_chat_manager_class=GraphFlowManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
)
|
||||
self._graph = graph
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], GraphFlowManager]:
|
||||
"""Creates the factory method for initializing the DiGraph-based chat manager."""
|
||||
|
||||
def _factory() -> GraphFlowManager:
|
||||
return GraphFlowManager(
|
||||
name=name,
|
||||
group_topic_type=group_topic_type,
|
||||
output_topic_type=output_topic_type,
|
||||
participant_topic_types=participant_topic_types,
|
||||
participant_names=participant_names,
|
||||
participant_descriptions=participant_descriptions,
|
||||
output_message_queue=output_message_queue,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
message_factory=message_factory,
|
||||
graph=self._graph,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> GraphFlowConfig:
|
||||
"""Converts the instance into a configuration object."""
|
||||
participants = [participant.dump_component() for participant in self._input_participants]
|
||||
termination_condition = (
|
||||
self._input_termination_condition.dump_component() if self._input_termination_condition else None
|
||||
)
|
||||
return GraphFlowConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
graph=self._graph,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: GraphFlowConfig) -> Self:
|
||||
"""Reconstructs an instance from a configuration object."""
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
participants=participants,
|
||||
graph=config.graph,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
)
|
||||
@@ -0,0 +1,209 @@
|
||||
import warnings
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
from agentdhal_agentchat.base import ChatAgent
|
||||
from agentdhal_agentchat.messages import BaseChatMessage
|
||||
|
||||
from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
|
||||
|
||||
|
||||
class DiGraphBuilder:
|
||||
"""
|
||||
A fluent builder for constructing :class:`DiGraph` execution graphs used in :class:`GraphFlow`.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This utility provides a convenient way to programmatically build a graph of agent interactions,
|
||||
including complex execution flows such as:
|
||||
|
||||
- Sequential chains
|
||||
- Parallel fan-outs
|
||||
- Conditional branching
|
||||
- Cyclic loops with safe exits
|
||||
|
||||
Each node in the graph represents an agent. Edges define execution paths between agents,
|
||||
and can optionally be conditioned on message content using callable functions.
|
||||
|
||||
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
|
||||
|
||||
Methods:
|
||||
- add_node(agent, activation): Add an agent node to the graph.
|
||||
- add_edge(source, target, condition): Connect two nodes optionally with a condition.
|
||||
- add_conditional_edges(source, condition_to_target): Add multiple conditional edges from a source.
|
||||
- set_entry_point(agent): Define the default start node (optional).
|
||||
- build(): Generate a validated `DiGraph`.
|
||||
- get_participants(): Return the list of added agents.
|
||||
|
||||
Example — Sequential Flow A → B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_b, agent_c)
|
||||
>>> team = Graph(
|
||||
... participants=builder.get_participants(),
|
||||
... graph=builder.build(),
|
||||
... termination_condition=MaxMessageTermination(5),
|
||||
... )
|
||||
|
||||
Example — Parallel Fan-out A → (B, C):
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
|
||||
Example — Conditional Branching A → B or A → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add conditional edges using keyword check
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="keyword1")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="keyword2")
|
||||
|
||||
|
||||
Example — Using Custom String Conditions:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add condition strings to check in messages
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="big")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="small")
|
||||
|
||||
Example — Loop: A → B → A or B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>> # Add a loop back to agent A
|
||||
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
|
||||
>>> # Add exit condition to break the loop
|
||||
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())
|
||||
|
||||
Example — Loop with multiple paths to the same node: A → B → C → B:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>>> builder.add_edge(agent_b, agent_c)
|
||||
>>> builder.add_edge(agent_c, agent_b, activation_group="loop_back")
|
||||
|
||||
Example — Loop with multiple paths to the same node with any activation condition: A → B → (C1, C2) → B → E(exit):
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c1).add_node(agent_c2).add_node(agent_e)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>>> builder.add_edge(agent_b, agent_c1)
|
||||
>>> builder.add_edge(agent_b, agent_c2)
|
||||
>>> builder.add_edge(agent_b, agent_e, condition="exit")
|
||||
>>> builder.add_edge(agent_c1, agent_b, activation_group="loop_back_group", activation_condition="any")
|
||||
>>> builder.add_edge(agent_c2, agent_b, activation_group="loop_back_group", activation_condition="any")
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.nodes: Dict[str, DiGraphNode] = {}
|
||||
self.agents: Dict[str, ChatAgent] = {}
|
||||
self._default_start_node: Optional[str] = None
|
||||
|
||||
def _get_name(self, obj: Union[str, ChatAgent]) -> str:
|
||||
return obj if isinstance(obj, str) else obj.name
|
||||
|
||||
def add_node(self, agent: ChatAgent, activation: Literal["all", "any"] = "all") -> "DiGraphBuilder":
|
||||
"""Add a node to the graph and register its agent."""
|
||||
name = agent.name
|
||||
if name not in self.nodes:
|
||||
self.nodes[name] = DiGraphNode(name=name, edges=[], activation=activation)
|
||||
self.agents[name] = agent
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Union[str, ChatAgent],
|
||||
target: Union[str, ChatAgent],
|
||||
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
|
||||
activation_group: Optional[str] = None,
|
||||
activation_condition: Optional[Literal["all", "any"]] = None,
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add a directed edge from source to target, optionally with a condition.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
target: Target node (agent name or agent object)
|
||||
condition: Optional condition for edge activation.
|
||||
If string, activates when substring is found in message.
|
||||
If callable, activates when function returns True for the message.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
|
||||
Raises:
|
||||
ValueError: If source or target node doesn't exist in the builder
|
||||
"""
|
||||
source_name = self._get_name(source)
|
||||
target_name = self._get_name(target)
|
||||
|
||||
if source_name not in self.nodes:
|
||||
raise ValueError(f"Source node '{source_name}' must be added before adding an edge.")
|
||||
if target_name not in self.nodes:
|
||||
raise ValueError(f"Target node '{target_name}' must be added before adding an edge.")
|
||||
if activation_group is None:
|
||||
activation_group = target_name
|
||||
if activation_condition is None:
|
||||
activation_condition = "all"
|
||||
self.nodes[source_name].edges.append(
|
||||
DiGraphEdge(
|
||||
target=target_name,
|
||||
condition=condition,
|
||||
activation_group=activation_group,
|
||||
activation_condition=activation_condition,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def add_conditional_edges(
|
||||
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add multiple conditional edges from a source node based on keyword checks.
|
||||
|
||||
.. warning::
|
||||
|
||||
This method interface will be changed in the future to support callable conditions.
|
||||
Please use `add_edge` if you need to specify custom conditions.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
condition_to_target: Mapping from condition strings to target nodes
|
||||
Each key is a keyword that will be checked in the message content
|
||||
Each value is the target node to activate when condition is met
|
||||
|
||||
For each key (keyword), a lambda will be created that checks
|
||||
if the keyword is in the message text.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"add_conditional_edges will be changed in the future to support callable conditions. "
|
||||
"For now, please use add_edge if you need to specify custom conditions.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
for condition_keyword, target in condition_to_target.items():
|
||||
self.add_edge(source, target, condition=condition_keyword)
|
||||
return self
|
||||
|
||||
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":
|
||||
"""Set the default start node of the graph."""
|
||||
node_name = self._get_name(name)
|
||||
if node_name not in self.nodes:
|
||||
raise ValueError(f"Start node '{node_name}' must be added before setting as entry point.")
|
||||
self._default_start_node = node_name
|
||||
return self
|
||||
|
||||
def build(self) -> DiGraph:
|
||||
"""Build and validate the DiGraph."""
|
||||
graph = DiGraph(
|
||||
nodes=self.nodes,
|
||||
default_start_node=self._default_start_node,
|
||||
)
|
||||
graph.graph_validate()
|
||||
return graph
|
||||
|
||||
def get_participants(self) -> list[ChatAgent]:
|
||||
"""Return the list of agents in the builder, in insertion order."""
|
||||
return list(self.agents.values())
|
||||
@@ -0,0 +1,5 @@
|
||||
from ._magentic_one_group_chat import MagenticOneGroupChat
|
||||
|
||||
__all__ = [
|
||||
"MagenticOneGroupChat",
|
||||
]
|
||||
@@ -0,0 +1,209 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, List
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ....base import ChatAgent, TerminationCondition
|
||||
from ....messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from .._base_group_chat import BaseGroupChat
|
||||
from .._events import GroupChatTermination
|
||||
from ._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||
from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class MagenticOneGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration for a MagenticOneGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
model_client: ComponentModel
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
max_stalls: int
|
||||
final_answer_prompt: str
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
|
||||
"""A team that runs a group chat with participants managed by the MagenticOneOrchestrator.
|
||||
|
||||
The orchestrator handles the conversation flow, ensuring that the task is completed
|
||||
efficiently by managing the participants' interactions.
|
||||
|
||||
The orchestrator is based on the Magentic-One architecture, which is a generalist multi-agent system for solving complex tasks (see references below).
|
||||
|
||||
Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and :class:`~agentdhal_agentchat.teams.SelectorGroupChat`,
|
||||
the MagenticOneGroupChat does not support using team as participant.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The participants in the group chat.
|
||||
model_client (ChatCompletionClient): The model client used for generating responses.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run based on the orchestrator logic or until the maximum number of turns is reached.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
|
||||
max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
|
||||
final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
|
||||
|
||||
Examples:
|
||||
|
||||
MagenticOneGroupChat with one assistant agent:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import MagenticOneGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
)
|
||||
team = MagenticOneGroupChat([assistant], model_client=model_client)
|
||||
await Console(team.run_stream(task="Provide a different proof to Fermat last theorem"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
References:
|
||||
|
||||
If you use the MagenticOneGroupChat in your work, please cite the following paper:
|
||||
|
||||
.. code-block:: bibtex
|
||||
|
||||
@article{fourney2024magentic,
|
||||
title={Magentic-one: A generalist multi-agent system for solving complex tasks},
|
||||
author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others},
|
||||
journal={arXiv preprint arXiv:2411.04468},
|
||||
year={2024}
|
||||
}
|
||||
"""
|
||||
|
||||
component_config_schema = MagenticOneGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.MagenticOneGroupChat"
|
||||
|
||||
DEFAULT_NAME = "MagenticOneGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = 20,
|
||||
runtime: AgentRuntime | None = None,
|
||||
max_stalls: int = 3,
|
||||
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
):
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=list(participants),
|
||||
group_chat_manager_name="MagenticOneOrchestrator",
|
||||
group_chat_manager_class=MagenticOneOrchestrator,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
# Validate the participants.
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required for MagenticOneGroupChat.")
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
self._final_answer_prompt = final_answer_prompt
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], MagenticOneOrchestrator]:
|
||||
return lambda: MagenticOneOrchestrator(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._model_client,
|
||||
self._max_stalls,
|
||||
self._final_answer_prompt,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
def _to_config(self) -> MagenticOneGroupChatConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return MagenticOneGroupChatConfig(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
participants=participants,
|
||||
model_client=self._model_client.dump_component(),
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
max_stalls=self._max_stalls,
|
||||
final_answer_prompt=self._final_answer_prompt,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MagenticOneGroupChatConfig) -> Self:
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants=participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
model_client=model_client,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
max_stalls=config.max_stalls,
|
||||
final_answer_prompt=config.final_answer_prompt,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
@@ -0,0 +1,536 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agentdhal_core.utils import extract_json_from_str
|
||||
|
||||
from .... import TRACE_LOGGER_NAME
|
||||
from ....base import Response, TerminationCondition
|
||||
from ....messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
MultiModalMessage,
|
||||
SelectSpeakerEvent,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ....state import MagenticOneOrchestratorState
|
||||
from ....utils import remove_images
|
||||
from .._base_group_chat_manager import BaseGroupChatManager
|
||||
from .._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatMessage,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatStart,
|
||||
GroupChatTeamResponse,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._prompts import (
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
|
||||
LedgerEntry,
|
||||
)
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
|
||||
class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
"""The MagenticOneOrchestrator manages a group chat with ledger based orchestration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
model_client: ChatCompletionClient,
|
||||
max_stalls: int,
|
||||
final_answer_prompt: str,
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
emit_team_events: bool,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
self._final_answer_prompt = final_answer_prompt
|
||||
self._max_json_retries = 10
|
||||
self._task = ""
|
||||
self._facts = ""
|
||||
self._plan = ""
|
||||
self._n_rounds = 0
|
||||
self._n_stalls = 0
|
||||
|
||||
# Produce a team description. Each agent sould appear on a single line.
|
||||
self._team_description = ""
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
self._team_description = self._team_description.strip()
|
||||
|
||||
def _get_task_ledger_facts_prompt(self, task: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT.format(task=task)
|
||||
|
||||
def _get_task_ledger_plan_prompt(self, team: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT.format(team=team)
|
||||
|
||||
def _get_task_ledger_full_prompt(self, task: str, team: str, facts: str, plan: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT.format(task=task, team=team, facts=facts, plan=plan)
|
||||
|
||||
def _get_progress_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
|
||||
return ORCHESTRATOR_PROGRESS_LEDGER_PROMPT.format(task=task, team=team, names=", ".join(names))
|
||||
|
||||
def _get_task_ledger_facts_update_prompt(self, task: str, facts: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT.format(task=task, facts=facts)
|
||||
|
||||
def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)
|
||||
|
||||
def _get_final_answer_prompt(self, task: str) -> str:
|
||||
if self._final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT:
|
||||
return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
|
||||
else:
|
||||
return self._final_answer_prompt
|
||||
|
||||
async def _log_message(self, log_message: str) -> None:
|
||||
trace_logger.debug(log_message)
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: # type: ignore
|
||||
"""Handle the start of a task."""
|
||||
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(content="The group chat has already terminated.", source=self._name)
|
||||
# Signal termination.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
assert message is not None and message.messages is not None
|
||||
|
||||
# Validate the group state given all the messages.
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
# Log the message to the output topic.
|
||||
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
# Log the message to the output queue.
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Outer Loop for first time
|
||||
# Create the initial task ledger
|
||||
#################################
|
||||
# Combine all message contents for task
|
||||
self._task = " ".join([msg.to_model_text() for msg in message.messages])
|
||||
planning_conversation: List[LLMMessage] = []
|
||||
|
||||
# 1. GATHER FACTS
|
||||
# create a closed book task and generate a response and update the chat history
|
||||
planning_conversation.append(
|
||||
UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
|
||||
)
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._facts = response.content
|
||||
planning_conversation.append(AssistantMessage(content=self._facts, source=self._name))
|
||||
|
||||
# 2. CREATE A PLAN
|
||||
## plan based on available information
|
||||
planning_conversation.append(
|
||||
UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
|
||||
)
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._plan = response.content
|
||||
|
||||
# Kick things off
|
||||
self._n_stalls = 0
|
||||
await self._reenter_outer_loop(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_agent_response( # type: ignore
|
||||
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
|
||||
) -> None: # type: ignore
|
||||
try:
|
||||
if not isinstance(message, GroupChatAgentResponse):
|
||||
raise RuntimeError("MagenticOneOrchestrator does not support GroupChatTeamResponse messages.")
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.response.inner_messages is not None:
|
||||
for inner_message in message.response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
await self.update_message_thread([message.response.chat_message])
|
||||
delta.append(message.response.chat_message)
|
||||
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination.
|
||||
await self._signal_termination(stop_message)
|
||||
return
|
||||
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = MagenticOneOrchestratorState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
task=self._task,
|
||||
facts=self._facts,
|
||||
plan=self._plan,
|
||||
n_rounds=self._n_rounds,
|
||||
n_stalls=self._n_stalls,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in orchestrator_state.message_thread]
|
||||
self._current_turn = orchestrator_state.current_turn
|
||||
self._task = orchestrator_state.task
|
||||
self._facts = orchestrator_state.facts
|
||||
self._plan = orchestrator_state.plan
|
||||
self._n_rounds = orchestrator_state.n_rounds
|
||||
self._n_stalls = orchestrator_state.n_stalls
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
||||
return [""]
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the group chat manager."""
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._n_rounds = 0
|
||||
self._n_stalls = 0
|
||||
self._task = ""
|
||||
self._facts = ""
|
||||
self._plan = ""
|
||||
|
||||
async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Re-enter Outer loop of the orchestrator after creating task ledger."""
|
||||
# Reset the agents
|
||||
for participant_topic_type in self._participant_name_to_topic_type.values():
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self.id.key),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Reset partially the group chat manager
|
||||
self._message_thread.clear()
|
||||
|
||||
# Prepare the ledger
|
||||
ledger_message = TextMessage(
|
||||
content=self._get_task_ledger_full_prompt(self._task, self._team_description, self._facts, self._plan),
|
||||
source=self._name,
|
||||
)
|
||||
|
||||
# Save my copy
|
||||
await self.update_message_thread([ledger_message])
|
||||
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=ledger_message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(ledger_message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=Response(chat_message=ledger_message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
||||
# Restart the inner loop
|
||||
await self._orchestrate_step(cancellation_token=cancellation_token)
|
||||
|
||||
async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Implements the inner loop of the orchestrator and selects next speaker."""
|
||||
# Check if we reached the maximum number of rounds
|
||||
if self._max_turns is not None and self._n_rounds > self._max_turns:
|
||||
await self._prepare_final_answer("Max rounds reached.", cancellation_token)
|
||||
return
|
||||
self._n_rounds += 1
|
||||
|
||||
# Update the progress ledger
|
||||
context = self._thread_to_context()
|
||||
|
||||
progress_ledger_prompt = self._get_progress_ledger_prompt(
|
||||
self._task, self._team_description, self._participant_names
|
||||
)
|
||||
context.append(UserMessage(content=progress_ledger_prompt, source=self._name))
|
||||
progress_ledger: Dict[str, Any] = {}
|
||||
assert self._max_json_retries > 0
|
||||
key_error: bool = False
|
||||
for _ in range(self._max_json_retries):
|
||||
if self._model_client.model_info.get("structured_output", False):
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), json_output=LedgerEntry
|
||||
)
|
||||
elif self._model_client.model_info.get("json_output", False):
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token, json_output=True
|
||||
)
|
||||
else:
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
ledger_str = response.content
|
||||
try:
|
||||
assert isinstance(ledger_str, str)
|
||||
output_json = extract_json_from_str(ledger_str)
|
||||
if len(output_json) != 1:
|
||||
raise ValueError(
|
||||
f"Progress ledger should contain a single JSON object, but found: {len(progress_ledger)}"
|
||||
)
|
||||
progress_ledger = output_json[0]
|
||||
|
||||
# If the team consists of a single agent, deterministically set the next speaker
|
||||
if len(self._participant_names) == 1:
|
||||
progress_ledger["next_speaker"] = {
|
||||
"reason": "The team consists of only one agent.",
|
||||
"answer": self._participant_names[0],
|
||||
}
|
||||
|
||||
# Validate the structure
|
||||
required_keys = [
|
||||
"is_request_satisfied",
|
||||
"is_progress_being_made",
|
||||
"is_in_loop",
|
||||
"instruction_or_question",
|
||||
"next_speaker",
|
||||
]
|
||||
|
||||
key_error = False
|
||||
for key in required_keys:
|
||||
if (
|
||||
key not in progress_ledger
|
||||
or not isinstance(progress_ledger[key], dict)
|
||||
or "answer" not in progress_ledger[key]
|
||||
or "reason" not in progress_ledger[key]
|
||||
):
|
||||
key_error = True
|
||||
break
|
||||
|
||||
# Validate the next speaker if the task is not yet complete
|
||||
if (
|
||||
not progress_ledger["is_request_satisfied"]["answer"]
|
||||
and progress_ledger["next_speaker"]["answer"] not in self._participant_names
|
||||
):
|
||||
key_error = True
|
||||
break
|
||||
|
||||
if not key_error:
|
||||
break
|
||||
await self._log_message(f"Failed to parse ledger information, retrying: {ledger_str}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
key_error = True
|
||||
await self._log_message("Invalid ledger format encountered, retrying...")
|
||||
continue
|
||||
if key_error:
|
||||
raise ValueError("Failed to parse ledger information after multiple retries.")
|
||||
await self._log_message(f"Progress Ledger: {progress_ledger}")
|
||||
|
||||
# Check for task completion
|
||||
if progress_ledger["is_request_satisfied"]["answer"]:
|
||||
await self._log_message("Task completed, preparing final answer...")
|
||||
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token)
|
||||
return
|
||||
|
||||
# Check for stalling
|
||||
if not progress_ledger["is_progress_being_made"]["answer"]:
|
||||
self._n_stalls += 1
|
||||
elif progress_ledger["is_in_loop"]["answer"]:
|
||||
self._n_stalls += 1
|
||||
else:
|
||||
self._n_stalls = max(0, self._n_stalls - 1)
|
||||
|
||||
# Too much stalling
|
||||
if self._n_stalls >= self._max_stalls:
|
||||
await self._log_message("Stall count exceeded, re-planning with the outer loop...")
|
||||
await self._update_task_ledger(cancellation_token)
|
||||
await self._reenter_outer_loop(cancellation_token)
|
||||
return
|
||||
|
||||
# Broadcast the next step
|
||||
message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
|
||||
await self.update_message_thread([message]) # My copy
|
||||
|
||||
await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast it
|
||||
await self.publish_message( # Broadcast
|
||||
GroupChatAgentResponse(response=Response(chat_message=message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Request that the step be completed
|
||||
next_speaker = progress_ledger["next_speaker"]["answer"]
|
||||
# Check if the next speaker is valid
|
||||
if next_speaker not in self._participant_name_to_topic_type:
|
||||
raise ValueError(
|
||||
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_names}"
|
||||
)
|
||||
participant_topic_type = self._participant_name_to_topic_type[next_speaker]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=participant_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Send the message to the next speaker
|
||||
if self._emit_team_events:
|
||||
select_msg = SelectSpeakerEvent(content=[next_speaker], source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
await self._output_message_queue.put(select_msg)
|
||||
|
||||
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Update the task ledger (outer loop) with the latest facts and plan."""
|
||||
context = self._thread_to_context()
|
||||
|
||||
# Update the facts
|
||||
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
|
||||
context.append(UserMessage(content=update_facts_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._facts = response.content
|
||||
context.append(AssistantMessage(content=self._facts, source=self._name))
|
||||
|
||||
# Update the plan
|
||||
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
|
||||
context.append(UserMessage(content=update_plan_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._plan = response.content
|
||||
|
||||
async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None:
|
||||
"""Prepare the final answer for the task."""
|
||||
context = self._thread_to_context()
|
||||
|
||||
# Get the final answer
|
||||
final_answer_prompt = self._get_final_answer_prompt(self._task)
|
||||
context.append(UserMessage(content=final_answer_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(
|
||||
self._get_compatible_context(context), cancellation_token=cancellation_token
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
message = TextMessage(content=response.content, source=self._name)
|
||||
|
||||
await self.update_message_thread([message]) # My copy
|
||||
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(response=Response(chat_message=message), name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination
|
||||
await self._signal_termination(StopMessage(content=reason, source=self._name))
|
||||
|
||||
def _thread_to_context(self) -> List[LLMMessage]:
|
||||
"""Convert the message thread to a context for the model."""
|
||||
context: List[LLMMessage] = []
|
||||
for m in self._message_thread:
|
||||
if isinstance(m, ToolCallRequestEvent | ToolCallExecutionEvent):
|
||||
# Ignore tool call messages.
|
||||
continue
|
||||
elif isinstance(m, StopMessage | HandoffMessage):
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
elif m.source == self._name:
|
||||
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
assert isinstance(m, (TextMessage, MultiModalMessage, ToolCallSummaryMessage))
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
return context
|
||||
|
||||
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
|
||||
if self._model_client.model_info["vision"]:
|
||||
return messages
|
||||
else:
|
||||
return remove_images(messages)
|
||||
@@ -0,0 +1,149 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
ORCHESTRATOR_SYSTEM_MESSAGE = ""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT = """Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
|
||||
|
||||
Here is the request:
|
||||
|
||||
{task}
|
||||
|
||||
Here is the pre-survey:
|
||||
|
||||
1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
|
||||
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
|
||||
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
|
||||
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
|
||||
|
||||
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:
|
||||
|
||||
1. GIVEN OR VERIFIED FACTS
|
||||
2. FACTS TO LOOK UP
|
||||
3. FACTS TO DERIVE
|
||||
4. EDUCATED GUESSES
|
||||
|
||||
DO NOT include any other headings or sections in your response. DO NOT list next steps or plans until asked to do so.
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT = """Fantastic. To address this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task."""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT = """
|
||||
We are working to address the following user request:
|
||||
|
||||
{task}
|
||||
|
||||
|
||||
To answer this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
|
||||
Here is an initial fact sheet to consider:
|
||||
|
||||
{facts}
|
||||
|
||||
|
||||
Here is the plan to follow as best as possible:
|
||||
|
||||
{plan}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT = """
|
||||
Recall we are working on the following request:
|
||||
|
||||
{task}
|
||||
|
||||
And we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
To make progress on the request, please answer the following questions, including necessary reasoning:
|
||||
|
||||
- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY and FULLY addressed)
|
||||
- Are we in a loop where we are repeating the same requests and / or getting the same responses as before? Loops can span multiple turns, and can include repeated actions like scrolling up or down more than a handful of times.
|
||||
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a loop or if there is evidence of significant barriers to success such as the inability to read from a required file)
|
||||
- Who should speak next? (select from: {names})
|
||||
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)
|
||||
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"is_request_satisfied": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_in_loop": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_progress_being_made": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"next_speaker": {{
|
||||
"reason": string,
|
||||
"answer": string (select from: {names})
|
||||
}},
|
||||
"instruction_or_question": {{
|
||||
"reason": string,
|
||||
"answer": string
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class LedgerEntryBooleanAnswer(BaseModel):
|
||||
reason: str
|
||||
answer: bool
|
||||
|
||||
|
||||
class LedgerEntryStringAnswer(BaseModel):
|
||||
reason: str
|
||||
answer: str
|
||||
|
||||
|
||||
class LedgerEntry(BaseModel):
|
||||
is_request_satisfied: LedgerEntryBooleanAnswer
|
||||
is_in_loop: LedgerEntryBooleanAnswer
|
||||
is_progress_being_made: LedgerEntryBooleanAnswer
|
||||
next_speaker: LedgerEntryStringAnswer
|
||||
instruction_or_question: LedgerEntryStringAnswer
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT = """As a reminder, we are working to solve the following task:
|
||||
|
||||
{task}
|
||||
|
||||
It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned that may be helpful. Example edits can include (but are not limited to) adding new guesses, moving educated guesses to verified facts if appropriate, etc. Updates may be made to any section of the fact sheet, and more than one section of the fact sheet can be edited. This is an especially good time to update educated guesses, so please at least add or update one educated guess or hunch, and explain your reasoning.
|
||||
|
||||
Here is the old fact sheet:
|
||||
|
||||
{facts}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
|
||||
|
||||
{team}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT = """
|
||||
We are working on the following task:
|
||||
{task}
|
||||
|
||||
We have completed the task.
|
||||
|
||||
The above messages contain the conversation that took place to complete the task.
|
||||
|
||||
Based on the information gathered, provide the final answer to the original request.
|
||||
The answer should be phrased as if you were speaking to the user.
|
||||
"""
|
||||
@@ -0,0 +1,328 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...base import ChatAgent, Team, TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
|
||||
from ...state import RoundRobinManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker in a round-robin fashion."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = RoundRobinManagerState(
|
||||
message_thread=[message.dump() for message in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
next_speaker_index=self._next_speaker_index,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
round_robin_state = RoundRobinManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread]
|
||||
self._current_turn = round_robin_state.current_turn
|
||||
self._next_speaker_index = round_robin_state.next_speaker_index
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants in a round-robin fashion.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
|
||||
current_speaker = self._participant_names[current_speaker_index]
|
||||
return current_speaker
|
||||
|
||||
|
||||
class RoundRobinGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration RoundRobinGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
||||
"""A team that runs a group chat with participants taking turns in a round-robin fashion
|
||||
to publish a message to all.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a single participant is in the team, the participant will be the only speaker.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent | Team]): The participants in the group chat.
|
||||
name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.RoundRobinGroupChat.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: If no participants are provided or if participant names are not unique.
|
||||
|
||||
Examples:
|
||||
|
||||
A team with one participant with tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def get_weather(location: str) -> str:
|
||||
return f"The weather in {location} is sunny."
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = RoundRobinGroupChat([assistant], termination_condition=termination)
|
||||
await Console(team.run_stream(task="What's the weather in New York?"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with multiple participants:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
await Console(team.run_stream(task="Tell me some jokes."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team of user proxy and a nested team of writer and reviewer agents:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import UserProxyAgent, AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination, MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
|
||||
writer = AssistantAgent(
|
||||
"writer", model_client=model_client, system_message="You are a writer.", model_client_stream=True
|
||||
)
|
||||
|
||||
reviewer = AssistantAgent(
|
||||
"reviewer",
|
||||
model_client=model_client,
|
||||
system_message="Provide feedback to the input and suggest improvements.",
|
||||
model_client_stream=True,
|
||||
)
|
||||
|
||||
# NOTE: you can skip input by pressing Enter.
|
||||
user_proxy = UserProxyAgent("user_proxy")
|
||||
|
||||
# Maximum 1 round of review and revision.
|
||||
inner_termination = MaxMessageTermination(max_messages=4)
|
||||
|
||||
# The outter-loop termination condition that will terminate the team when the user types "exit".
|
||||
outter_termination = TextMentionTermination("exit", sources=["user_proxy"])
|
||||
|
||||
team = RoundRobinGroupChat(
|
||||
[
|
||||
# For each turn, the writer writes a summary and the reviewer reviews it.
|
||||
RoundRobinGroupChat([writer, reviewer], termination_condition=inner_termination),
|
||||
# The user proxy gets user input once the writer and reviewer have finished their actions.
|
||||
user_proxy,
|
||||
],
|
||||
termination_condition=outter_termination,
|
||||
)
|
||||
# Start the team and wait for it to terminate.
|
||||
await Console(team.run_stream(task="Write a short essay about the impact of AI on society."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = RoundRobinGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.RoundRobinGroupChat"
|
||||
|
||||
DEFAULT_NAME = "RoundRobinGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent | Team],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=participants,
|
||||
group_chat_manager_name="RoundRobinGroupChatManager",
|
||||
group_chat_manager_class=RoundRobinGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], RoundRobinGroupChatManager]:
|
||||
def _factory() -> RoundRobinGroupChatManager:
|
||||
return RoundRobinGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> RoundRobinGroupChatConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return RoundRobinGroupChatConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
|
||||
participants: List[ChatAgent | Team] = []
|
||||
for participant in config.participants:
|
||||
if participant.component_type == Team.component_type:
|
||||
participants.append(Team.load_component(participant))
|
||||
else:
|
||||
participants.append(ChatAgent.load_component(participant))
|
||||
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
@@ -0,0 +1,730 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from agentdhal_core import AgentRuntime, CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ... import TRACE_LOGGER_NAME
|
||||
from ...base import ChatAgent, Team, TerminationCondition
|
||||
from ...messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
SelectorEvent,
|
||||
)
|
||||
from ...state import SelectorManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
|
||||
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
|
||||
SelectorFuncType = Union[SyncSelectorFunc | AsyncSelectorFunc]
|
||||
|
||||
SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
|
||||
AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]]
|
||||
CandidateFuncType = Union[SyncCandidateFunc | AsyncCandidateFunc]
|
||||
|
||||
|
||||
class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker using a ChatCompletion
|
||||
model and a custom selector function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
model_client: ChatCompletionClient,
|
||||
selector_prompt: str,
|
||||
allow_repeated_speaker: bool,
|
||||
selector_func: Optional[SelectorFuncType],
|
||||
max_selector_attempts: int,
|
||||
candidate_func: Optional[CandidateFuncType],
|
||||
emit_team_events: bool,
|
||||
model_context: ChatCompletionContext | None,
|
||||
model_client_streaming: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._selector_prompt = selector_prompt
|
||||
self._previous_speaker: str | None = None
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._is_selector_func_async = iscoroutinefunction(self._selector_func)
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
|
||||
self._model_client_streaming = model_client_streaming
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
self._cancellation_token = CancellationToken()
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
await self._model_context.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._previous_speaker = None
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = SelectorManagerState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
previous_speaker=self._previous_speaker,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
selector_state = SelectorManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
|
||||
await self._add_messages_to_context(
|
||||
self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
|
||||
)
|
||||
self._current_turn = selector_state.current_turn
|
||||
self._previous_speaker = selector_state.previous_speaker
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
self._message_thread.extend(messages)
|
||||
base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
|
||||
await self._add_messages_to_context(self._model_context, base_chat_messages)
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||
with the selector function as override if it returns a speaker name.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker name.
|
||||
|
||||
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
|
||||
"""
|
||||
# Use the selector function if provided.
|
||||
if self._selector_func is not None:
|
||||
if self._is_selector_func_async:
|
||||
async_selector_func = cast(AsyncSelectorFunc, self._selector_func)
|
||||
speaker = await async_selector_func(thread)
|
||||
else:
|
||||
sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
|
||||
speaker = sync_selector_func(thread)
|
||||
if speaker is not None:
|
||||
if speaker not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"Selector function returned an invalid speaker name: {speaker}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
# Skip the model based selection.
|
||||
return [speaker]
|
||||
|
||||
# Use the candidate function to filter participants if provided
|
||||
if self._candidate_func is not None:
|
||||
if self._is_candidate_func_async:
|
||||
async_candidate_func = cast(AsyncCandidateFunc, self._candidate_func)
|
||||
participants = await async_candidate_func(thread)
|
||||
else:
|
||||
sync_candidate_func = cast(SyncCandidateFunc, self._candidate_func)
|
||||
participants = sync_candidate_func(thread)
|
||||
if not participants:
|
||||
raise ValueError("Candidate function must return a non-empty list of participant names.")
|
||||
if not all(p in self._participant_names for p in participants):
|
||||
raise ValueError(
|
||||
f"Candidate function returned invalid participant names: {participants}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
else:
|
||||
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
|
||||
if self._previous_speaker is not None and not self._allow_repeated_speaker:
|
||||
participants = [p for p in self._participant_names if p != self._previous_speaker]
|
||||
else:
|
||||
participants = list(self._participant_names)
|
||||
|
||||
assert len(participants) > 0
|
||||
|
||||
# Construct agent roles.
|
||||
# Each agent sould appear on a single line.
|
||||
roles = ""
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
roles = roles.strip()
|
||||
|
||||
# Select the next speaker.
|
||||
if len(participants) > 1:
|
||||
agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
|
||||
else:
|
||||
agent_name = participants[0]
|
||||
self._previous_speaker = agent_name
|
||||
trace_logger.debug(f"Selected speaker: {agent_name}")
|
||||
return [agent_name]
|
||||
|
||||
def construct_message_history(self, message_history: List[LLMMessage]) -> str:
|
||||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for msg in message_history:
|
||||
if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
|
||||
message = f"{msg.source}: {msg.content}"
|
||||
history_messages.append(
|
||||
message.rstrip() + "\n\n"
|
||||
) # Create some consistency for how messages are separated in the transcript
|
||||
|
||||
history: str = "\n".join(history_messages)
|
||||
return history
|
||||
|
||||
async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
|
||||
model_context_messages = await self._model_context.get_messages()
|
||||
model_context_history = self.construct_message_history(model_context_messages)
|
||||
|
||||
select_speaker_prompt = self._selector_prompt.format(
|
||||
roles=roles, participants=str(participants), history=model_context_history
|
||||
)
|
||||
|
||||
select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
|
||||
if ModelFamily.is_openai(self._model_client.model_info["family"]):
|
||||
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
|
||||
else:
|
||||
# Many other models need a UserMessage to respond to
|
||||
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")]
|
||||
|
||||
num_attempts = 0
|
||||
while num_attempts < max_attempts:
|
||||
num_attempts += 1
|
||||
if self._model_client_streaming:
|
||||
chunk: CreateResult | str = ""
|
||||
async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
|
||||
chunk = _chunk
|
||||
if self._emit_team_events:
|
||||
if isinstance(chunk, str):
|
||||
await self._output_message_queue.put(
|
||||
ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
|
||||
)
|
||||
else:
|
||||
assert isinstance(chunk, CreateResult)
|
||||
assert isinstance(chunk.content, str)
|
||||
await self._output_message_queue.put(
|
||||
SelectorEvent(content=chunk.content, source=self._name)
|
||||
)
|
||||
# The last chunk must be CreateResult.
|
||||
assert isinstance(chunk, CreateResult)
|
||||
response = chunk
|
||||
else:
|
||||
response = await self._model_client.create(messages=select_speaker_messages)
|
||||
assert isinstance(response.content, str)
|
||||
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
|
||||
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
|
||||
# This is because the model may still select the previous speaker, and we want to catch that.
|
||||
mentions = self._mentioned_agents(response.content, self._participant_names)
|
||||
if len(mentions) == 0:
|
||||
trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
|
||||
feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
elif len(mentions) > 1:
|
||||
trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})")
|
||||
feedback = (
|
||||
f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}."
|
||||
)
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
else:
|
||||
agent_name = list(mentions.keys())[0]
|
||||
if (
|
||||
not self._allow_repeated_speaker
|
||||
and self._previous_speaker is not None
|
||||
and agent_name == self._previous_speaker
|
||||
):
|
||||
trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
|
||||
feedback = (
|
||||
f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
|
||||
)
|
||||
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
|
||||
else:
|
||||
# Valid selection
|
||||
trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
|
||||
return agent_name
|
||||
|
||||
if self._previous_speaker is not None:
|
||||
trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
|
||||
return self._previous_speaker
|
||||
trace_logger.warning(
|
||||
f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
|
||||
)
|
||||
return participants[0]
|
||||
|
||||
def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
|
||||
"""Counts the number of times each agent is mentioned in the provided message content.
|
||||
Agent names will match under any of the following conditions (all case-sensitive):
|
||||
- Exact name match
|
||||
- If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
|
||||
- If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
|
||||
|
||||
Args:
|
||||
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
|
||||
agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content.
|
||||
|
||||
Returns:
|
||||
Dict: a counter for mentioned agents.
|
||||
"""
|
||||
mentions: Dict[str, int] = dict()
|
||||
for name in agent_names:
|
||||
# Finds agent mentions, taking word boundaries into account,
|
||||
# accommodates escaping underscores and underscores as spaces
|
||||
regex = (
|
||||
r"(?<=\W)("
|
||||
+ re.escape(name)
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", " "))
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", r"\_"))
|
||||
+ r")(?=\W)"
|
||||
)
|
||||
# Pad the message to help with matching
|
||||
count = len(re.findall(regex, f" {message_content} "))
|
||||
if count > 0:
|
||||
mentions[name] = count
|
||||
return mentions
|
||||
|
||||
|
||||
class SelectorGroupChatConfig(BaseModel):
|
||||
"""The declarative configuration for SelectorGroupChat."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
model_client: ComponentModel
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
selector_prompt: str
|
||||
allow_repeated_speaker: bool
|
||||
# selector_func: ComponentModel | None
|
||||
max_selector_attempts: int = 3
|
||||
emit_team_events: bool = False
|
||||
model_client_streaming: bool = False
|
||||
model_context: ComponentModel | None = None
|
||||
|
||||
|
||||
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||
"""A group chat team that have participants takes turn to publish a message
|
||||
to all, using a ChatCompletion model to select the next speaker after each message.
|
||||
|
||||
If an :class:`~agentdhal_agentchat.base.ChatAgent` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage` from the agent response's
|
||||
:attr:`~agentdhal_agentchat.base.Response.chat_message` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
If a :class:`~agentdhal_agentchat.base.Team` is a participant,
|
||||
the :class:`~agentdhal_agentchat.messages.BaseChatMessage`
|
||||
from the team result' :attr:`~agentdhal_agentchat.base.TaskResult.messages` will be published
|
||||
to other participants in the group chat.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent | Team]): The participants in the group chat,
|
||||
must have unique names and at least two participants.
|
||||
model_client (ChatCompletionClient): The ChatCompletion model client used
|
||||
to select the next speaker.
|
||||
name (str | None, optional): The name of the group chat, using
|
||||
:attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must
|
||||
be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using
|
||||
:attr:`~agentdhal_agentchat.teams.SelectorGroupChat.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
selector_prompt (str, optional): The prompt template to use for selecting the next speaker.
|
||||
Available fields: '{roles}', '{participants}', and '{history}'.
|
||||
`{participants}` is the names of candidates for selection. The format is `["<name1>", "<name2>", ...]`.
|
||||
`{roles}` is a newline-separated list of names and descriptions of the candidate agents. The format for each line is: `"<name> : <description>"`.
|
||||
`{history}` is the conversation history formatted as a double newline separated of names and message content. The format for each message is: `"<name> : <message content>"`.
|
||||
allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn.
|
||||
Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens.
|
||||
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
|
||||
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
|
||||
otherwise the first participant will be used.
|
||||
selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
|
||||
function that takes the conversation history and returns the name of the next speaker.
|
||||
If provided, this function will be used to override the model to select the next speaker.
|
||||
If the function returns None, the model will be used to select the next speaker.
|
||||
NOTE: `selector_func` is not serializable and will be ignored during serialization and deserialization process.
|
||||
candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
|
||||
A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
|
||||
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
|
||||
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
|
||||
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
|
||||
:class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
||||
|
||||
Examples:
|
||||
|
||||
A team with multiple participants:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = SelectorGroupChat(
|
||||
[travel_advisor, hotel_agent, flight_agent],
|
||||
model_client=model_client,
|
||||
termination_condition=termination,
|
||||
)
|
||||
await Console(team.run_stream(task="Book a 3-day trip to new york."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with a custom selector function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from typing import Sequence
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
def check_calculation(x: int, y: int, answer: int) -> str:
|
||||
if x + y == answer:
|
||||
return "Correct!"
|
||||
else:
|
||||
return "Incorrect!"
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Agent1",
|
||||
model_client,
|
||||
description="For calculation",
|
||||
system_message="Calculate the sum of two numbers",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Agent2",
|
||||
model_client,
|
||||
tools=[check_calculation],
|
||||
description="For checking calculation",
|
||||
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
|
||||
)
|
||||
|
||||
def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
|
||||
if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
|
||||
return "Agent1"
|
||||
if messages[-1].source == "Agent1":
|
||||
return "Agent2"
|
||||
return None
|
||||
|
||||
termination = TextMentionTermination("Correct!")
|
||||
team = SelectorGroupChat(
|
||||
[agent1, agent2],
|
||||
model_client=model_client,
|
||||
selector_func=selector_func,
|
||||
termination_condition=termination,
|
||||
)
|
||||
|
||||
await Console(team.run_stream(task="What is 1 + 1?"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with custom model context:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_core.model_context import BufferedChatCompletionContext
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.teams import SelectorGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
model_context = BufferedChatCompletionContext(buffer_size=5)
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = SelectorGroupChat(
|
||||
[travel_advisor, hotel_agent, flight_agent],
|
||||
model_client=model_client,
|
||||
termination_condition=termination,
|
||||
model_context=model_context,
|
||||
)
|
||||
await Console(team.run_stream(task="Book a 3-day trip to new york."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SelectorGroupChatConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.SelectorGroupChat"
|
||||
|
||||
DEFAULT_NAME = "SelectorGroupChat"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent | Team],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
|
||||
{history}
|
||||
|
||||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
""",
|
||||
allow_repeated_speaker: bool = False,
|
||||
max_selector_attempts: int = 3,
|
||||
selector_func: Optional[SelectorFuncType] = None,
|
||||
candidate_func: Optional[CandidateFuncType] = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
model_client_streaming: bool = False,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=participants,
|
||||
group_chat_manager_name="SelectorGroupChatManager",
|
||||
group_chat_manager_class=SelectorGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# Validate the participants.
|
||||
if len(participants) < 2:
|
||||
raise ValueError("At least two participants are required for SelectorGroupChat.")
|
||||
self._selector_prompt = selector_prompt
|
||||
self._model_client = model_client
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
self._max_selector_attempts = max_selector_attempts
|
||||
self._candidate_func = candidate_func
|
||||
self._model_client_streaming = model_client_streaming
|
||||
self._model_context = model_context
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], BaseGroupChatManager]:
|
||||
return lambda: SelectorGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._model_client,
|
||||
self._selector_prompt,
|
||||
self._allow_repeated_speaker,
|
||||
self._selector_func,
|
||||
self._max_selector_attempts,
|
||||
self._candidate_func,
|
||||
self._emit_team_events,
|
||||
self._model_context,
|
||||
self._model_client_streaming,
|
||||
)
|
||||
|
||||
def _to_config(self) -> SelectorGroupChatConfig:
|
||||
return SelectorGroupChatConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=[participant.dump_component() for participant in self._participants],
|
||||
model_client=self._model_client.dump_component(),
|
||||
termination_condition=self._termination_condition.dump_component() if self._termination_condition else None,
|
||||
max_turns=self._max_turns,
|
||||
selector_prompt=self._selector_prompt,
|
||||
allow_repeated_speaker=self._allow_repeated_speaker,
|
||||
max_selector_attempts=self._max_selector_attempts,
|
||||
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
|
||||
emit_team_events=self._emit_team_events,
|
||||
model_client_streaming=self._model_client_streaming,
|
||||
model_context=self._model_context.dump_component() if self._model_context else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
|
||||
participants: List[ChatAgent | Team] = []
|
||||
for participant in config.participants:
|
||||
if participant.component_type == ChatAgent.component_type:
|
||||
participants.append(ChatAgent.load_component(participant))
|
||||
elif participant.component_type == Team.component_type:
|
||||
participants.append(Team.load_component(participant))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid participant component type: {participant.component_type}. " "Expected ChatAgent or Team."
|
||||
)
|
||||
return cls(
|
||||
participants=participants,
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=TerminationCondition.load_component(config.termination_condition)
|
||||
if config.termination_condition
|
||||
else None,
|
||||
max_turns=config.max_turns,
|
||||
selector_prompt=config.selector_prompt,
|
||||
allow_repeated_speaker=config.allow_repeated_speaker,
|
||||
max_selector_attempts=config.max_selector_attempts,
|
||||
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
|
||||
# if config.selector_func
|
||||
# else None,
|
||||
emit_team_events=config.emit_team_events,
|
||||
model_client_streaming=config.model_client_streaming,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
from typing import Any, Sequence
|
||||
|
||||
from agentdhal_core import MessageContext, RoutedAgent
|
||||
|
||||
|
||||
class FIFOLock:
|
||||
"""A lock that ensures coroutines acquire the lock in the order they request it."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queue = asyncio.Queue[asyncio.Event]()
|
||||
self._locked = False
|
||||
|
||||
async def acquire(self) -> None:
|
||||
# If the lock is not held by any coroutine, set the lock to be held
|
||||
# by the current coroutine.
|
||||
if not self._locked:
|
||||
self._locked = True
|
||||
return
|
||||
|
||||
# If the lock is held by another coroutine, create an event and put it
|
||||
# in the queue. Wait for the event to be set.
|
||||
event = asyncio.Event()
|
||||
await self._queue.put(event)
|
||||
await event.wait()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._queue.empty():
|
||||
# If there are events in the queue, get the next event and set it.
|
||||
next_event = self._queue.get_nowait()
|
||||
next_event.set()
|
||||
else:
|
||||
# If there are no events in the queue, release the lock.
|
||||
self._locked = False
|
||||
|
||||
|
||||
class SequentialRoutedAgent(RoutedAgent):
|
||||
"""A subclass of :class:`agentdhal_core.RoutedAgent` that ensures
|
||||
that messages of certain types are processed sequentially
|
||||
using a FIFO lock.
|
||||
|
||||
This is useful for agents that need to maintain a strict order of
|
||||
processing messages, such as in a group chat scenario.
|
||||
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
description (str): The description of the agent.
|
||||
sequential_message_types (Sequence[Type[Any]]): A sequence of message types that should be
|
||||
processed sequentially. If a message of one of these types is received,
|
||||
the agent will acquire a FIFO lock to ensure that it is processed
|
||||
before any later messages that are also one of these types.
|
||||
"""
|
||||
|
||||
def __init__(self, description: str, sequential_message_types: Sequence[type[Any]]) -> None:
|
||||
super().__init__(description=description)
|
||||
self._fifo_lock = FIFOLock()
|
||||
self._sequential_message_types = sequential_message_types
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
if any(isinstance(message, sequential_type) for sequential_type in self._sequential_message_types):
|
||||
# Acquire the FIFO lock to ensure that this message is processed
|
||||
# in the order it was received.
|
||||
await self._fifo_lock.acquire()
|
||||
try:
|
||||
return await super().on_message_impl(message, ctx)
|
||||
finally:
|
||||
# Release the FIFO lock to allow the next message to be processed.
|
||||
self._fifo_lock.release()
|
||||
# If the message is not of a sequential type, process it normally.
|
||||
return await super().on_message_impl(message, ctx)
|
||||
@@ -0,0 +1,321 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, MessageFactory
|
||||
from ...state import SwarmManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker based on handoff message only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
emit_team_events: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
emit_team_events,
|
||||
)
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
"""Validate the start messages for the group chat."""
|
||||
# Check if any of the start messages is a handoff message.
|
||||
if messages:
|
||||
for message in messages:
|
||||
if isinstance(message, HandoffMessage):
|
||||
if message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The target {message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if there is a handoff message in the thread that is not targeting a valid participant.
|
||||
for existing_message in reversed(self._message_thread):
|
||||
if isinstance(existing_message, HandoffMessage):
|
||||
if existing_message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new task make sure to include in your task "
|
||||
"a HandoffMessage with a valid participant as the target. For example, if you are "
|
||||
"resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
|
||||
"with a valid participant as the target."
|
||||
)
|
||||
# The latest handoff message should always target a valid participant.
|
||||
# Do not look past the latest handoff message.
|
||||
return
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._current_turn = 0
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
Looks for the last handoff message in the thread to determine the next speaker.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
if len(thread) == 0:
|
||||
return [self._current_speaker]
|
||||
for message in reversed(thread):
|
||||
if isinstance(message, HandoffMessage):
|
||||
self._current_speaker = message.target
|
||||
# The latest handoff message should always target a valid participant.
|
||||
assert self._current_speaker in self._participant_names
|
||||
return [self._current_speaker]
|
||||
return self._current_speaker
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = SwarmManagerState(
|
||||
message_thread=[msg.dump() for msg in self._message_thread],
|
||||
current_turn=self._current_turn,
|
||||
current_speaker=self._current_speaker,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
swarm_state = SwarmManagerState.model_validate(state)
|
||||
self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread]
|
||||
self._current_turn = swarm_state.current_turn
|
||||
self._current_speaker = swarm_state.current_speaker
|
||||
|
||||
|
||||
class SwarmConfig(BaseModel):
|
||||
"""The declarative configuration for Swarm."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
participants: List[ComponentModel]
|
||||
termination_condition: ComponentModel | None = None
|
||||
max_turns: int | None = None
|
||||
emit_team_events: bool = False
|
||||
|
||||
|
||||
class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
||||
"""A group chat team that selects the next speaker based on handoff message only.
|
||||
|
||||
The first participant in the list of participants is the initial speaker.
|
||||
The next speaker is selected based on the :class:`~agentdhal_agentchat.messages.HandoffMessage` message
|
||||
sent by the current speaker. If no handoff message is sent, the current speaker
|
||||
continues to be the speaker.
|
||||
|
||||
.. note::
|
||||
|
||||
Unlike :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` and
|
||||
:class:`~agentdhal_agentchat.teams.SelectorGroupChat`, this group chat
|
||||
team does not support inner teams as participants.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
|
||||
name (str | None, optional): The name of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_NAME` if not provided.
|
||||
The name is used by a parent team to identify this group chat so it must be unique within the parent team.
|
||||
description (str | None, optional): The description of the group chat, using :attr:`~agentdhal_agentchat.teams.Swarm.DEFAULT_DESCRIPTION` if not provided.
|
||||
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
|
||||
Without a termination condition, the group chat will run indefinitely.
|
||||
max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
|
||||
custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
|
||||
If you are using custom message types or your agents produces custom message types, you need to specify them here.
|
||||
Make sure your custom message types are subclasses of :class:`~agentdhal_agentchat.messages.BaseAgentEvent` or :class:`~agentdhal_agentchat.messages.BaseChatMessage`.
|
||||
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
|
||||
|
||||
Basic example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import Swarm
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["Bob"],
|
||||
system_message="You are Alice and you only answer questions about yourself.",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
|
||||
)
|
||||
|
||||
termination = MaxMessageTermination(3)
|
||||
team = Swarm([agent1, agent2], termination_condition=termination)
|
||||
|
||||
stream = team.run_stream(task="What is bob's birthday?")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Using the :class:`~agentdhal_agentchat.conditions.HandoffTermination` for human-in-the-loop handoff:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import Swarm
|
||||
from agentdhal_agentchat.conditions import HandoffTermination, MaxMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.messages import HandoffMessage
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["user"],
|
||||
system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.",
|
||||
)
|
||||
termination = HandoffTermination(target="user") | MaxMessageTermination(3)
|
||||
team = Swarm([agent], termination_condition=termination)
|
||||
|
||||
# Start the conversation.
|
||||
await Console(team.run_stream(task="What is bob's birthday?"))
|
||||
|
||||
# Resume with user feedback.
|
||||
await Console(
|
||||
team.run_stream(
|
||||
task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SwarmConfig
|
||||
component_provider_override = "agentdhal_agentchat.teams.Swarm"
|
||||
|
||||
DEFAULT_NAME = "Swarm"
|
||||
DEFAULT_DESCRIPTION = "A team of agents."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
|
||||
emit_team_events: bool = False,
|
||||
) -> None:
|
||||
for participant in participants:
|
||||
if not isinstance(participant, ChatAgent):
|
||||
raise TypeError(f"Participant {participant} must be a ChatAgent.")
|
||||
super().__init__(
|
||||
name=name or self.DEFAULT_NAME,
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
participants=[participant for participant in participants],
|
||||
group_chat_manager_name="SwarmGroupChatManager",
|
||||
group_chat_manager_class=SwarmGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
custom_message_types=custom_message_types,
|
||||
emit_team_events=emit_team_events,
|
||||
)
|
||||
# The first participant must be able to produce handoff messages.
|
||||
first_participant = self._participants[0]
|
||||
assert isinstance(first_participant, ChatAgent)
|
||||
if HandoffMessage not in first_participant.produced_message_types:
|
||||
raise ValueError("The first participant must be able to produce a handoff messages.")
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
message_factory: MessageFactory,
|
||||
) -> Callable[[], SwarmGroupChatManager]:
|
||||
def _factory() -> SwarmGroupChatManager:
|
||||
return SwarmGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
message_factory,
|
||||
self._emit_team_events,
|
||||
)
|
||||
|
||||
return _factory
|
||||
|
||||
def _to_config(self) -> SwarmConfig:
|
||||
participants = [participant.dump_component() for participant in self._participants]
|
||||
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
|
||||
return SwarmConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
participants=participants,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
emit_team_events=self._emit_team_events,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SwarmConfig) -> "Swarm":
|
||||
participants = [ChatAgent.load_component(participant) for participant in config.participants]
|
||||
termination_condition = (
|
||||
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
|
||||
)
|
||||
return cls(
|
||||
participants,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=config.max_turns,
|
||||
emit_team_events=config.emit_team_events,
|
||||
)
|
||||
4
agent_dhal/agentdhal_agentchat/tools/__init__.py
Normal file
4
agent_dhal/agentdhal_agentchat/tools/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from ._agent import AgentTool
|
||||
from ._team import TeamTool
|
||||
|
||||
__all__ = ["AgentTool", "TeamTool"]
|
||||
93
agent_dhal/agentdhal_agentchat/tools/_agent.py
Normal file
93
agent_dhal/agentdhal_agentchat/tools/_agent.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from agentdhal_core import Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
|
||||
from ._task_runner_tool import TaskRunnerTool
|
||||
|
||||
|
||||
class AgentToolConfig(BaseModel):
|
||||
"""Configuration for the AgentTool."""
|
||||
|
||||
agent: ComponentModel
|
||||
"""The agent to be used for running the task."""
|
||||
|
||||
return_value_as_last_message: bool = False
|
||||
"""Whether to return the value as the last message of the task result."""
|
||||
|
||||
|
||||
class AgentTool(TaskRunnerTool, Component[AgentToolConfig]):
|
||||
"""Tool that can be used to run a task using an agent.
|
||||
|
||||
The tool returns the result of the task execution as a :class:`~agentdhal_agentchat.base.TaskResult` object.
|
||||
|
||||
.. important::
|
||||
When using AgentTool, you **must** disable parallel tool calls in the model client configuration
|
||||
to avoid concurrency issues. Agents cannot run concurrently as they maintain internal state
|
||||
that would conflict with parallel execution. For example, set ``parallel_tool_calls=False``
|
||||
for :class:`~agentdhal_extensions.models.openai.OpenAIChatCompletionClient` and
|
||||
:class:`~agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient`.
|
||||
|
||||
Args:
|
||||
agent (BaseChatAgent): The agent to be used for running the task.
|
||||
return_value_as_last_message (bool): Whether to use the last message content of the task result
|
||||
as the return value of the tool in :meth:`~agentdhal_agentchat.tools.TaskRunnerTool.return_value_as_string`.
|
||||
If set to True, the last message content will be returned as a string.
|
||||
If set to False, the tool will return all messages in the task result as a string concatenated together,
|
||||
with each message prefixed by its source (e.g., "writer: ...", "assistant: ...").
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.tools import AgentTool
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
writer = AssistantAgent(
|
||||
name="writer",
|
||||
description="A writer agent for generating text.",
|
||||
model_client=model_client,
|
||||
system_message="Write well.",
|
||||
)
|
||||
writer_tool = AgentTool(agent=writer)
|
||||
|
||||
# Create model client with parallel tool calls disabled for the main agent
|
||||
main_model_client = OpenAIChatCompletionClient(model="gpt-4.1", parallel_tool_calls=False)
|
||||
assistant = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=main_model_client,
|
||||
tools=[writer_tool],
|
||||
system_message="You are a helpful assistant.",
|
||||
)
|
||||
await Console(assistant.run_stream(task="Write a poem about the sea."))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = AgentToolConfig
|
||||
component_provider_override = "agentdhal_agentchat.tools.AgentTool"
|
||||
|
||||
def __init__(self, agent: BaseChatAgent, return_value_as_last_message: bool = False) -> None:
|
||||
self._agent = agent
|
||||
super().__init__(
|
||||
agent, agent.name, agent.description, return_value_as_last_message=return_value_as_last_message
|
||||
)
|
||||
|
||||
def _to_config(self) -> AgentToolConfig:
|
||||
return AgentToolConfig(
|
||||
agent=self._agent.dump_component(),
|
||||
return_value_as_last_message=self._return_value_as_last_message,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: AgentToolConfig) -> Self:
|
||||
return cls(BaseChatAgent.load_component(config.agent), config.return_value_as_last_message)
|
||||
72
agent_dhal/agentdhal_agentchat/tools/_task_runner_tool.py
Normal file
72
agent_dhal/agentdhal_agentchat/tools/_task_runner_tool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from abc import ABC
|
||||
from typing import Annotated, Any, AsyncGenerator, List, Mapping
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseStreamTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents import BaseChatAgent
|
||||
from ..base import TaskResult
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
from ..teams import BaseGroupChat
|
||||
|
||||
|
||||
class TaskRunnerToolArgs(BaseModel):
|
||||
"""Input for the TaskRunnerTool."""
|
||||
|
||||
task: Annotated[str, "The task to be executed."]
|
||||
|
||||
|
||||
class TaskRunnerTool(BaseStreamTool[TaskRunnerToolArgs, BaseAgentEvent | BaseChatMessage, TaskResult], ABC):
|
||||
"""An base class for tool that can be used to run a task using a team or an agent."""
|
||||
|
||||
component_type = "tool"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_runner: BaseGroupChat | BaseChatAgent,
|
||||
name: str,
|
||||
description: str,
|
||||
return_value_as_last_message: bool,
|
||||
) -> None:
|
||||
self._task_runner = task_runner
|
||||
self._return_value_as_last_message = return_value_as_last_message
|
||||
super().__init__(
|
||||
args_type=TaskRunnerToolArgs,
|
||||
return_type=TaskResult,
|
||||
name=name,
|
||||
description=description,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
async def run(self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken) -> TaskResult:
|
||||
"""Run the task and return the result."""
|
||||
return await self._task_runner.run(task=args.task, cancellation_token=cancellation_token)
|
||||
|
||||
async def run_stream(
|
||||
self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the task and yield events or messages as they are produced, the final :class:`TaskResult`
|
||||
will be yielded at the end."""
|
||||
async for event in self._task_runner.run_stream(task=args.task, cancellation_token=cancellation_token):
|
||||
yield event
|
||||
|
||||
def return_value_as_string(self, value: TaskResult) -> str:
|
||||
"""Convert the task result to a string."""
|
||||
if self._return_value_as_last_message:
|
||||
if value.messages and isinstance(value.messages[-1], BaseChatMessage):
|
||||
return value.messages[-1].to_model_text()
|
||||
raise ValueError("The last message is not a BaseChatMessage.")
|
||||
parts: List[str] = []
|
||||
for message in value.messages:
|
||||
if isinstance(message, BaseChatMessage):
|
||||
if message.source == "user":
|
||||
continue
|
||||
parts.append(f"{message.source}: {message.to_model_text()}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]:
|
||||
return await self._task_runner.save_state()
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
await self._task_runner.load_state(state)
|
||||
133
agent_dhal/agentdhal_agentchat/tools/_team.py
Normal file
133
agent_dhal/agentdhal_agentchat/tools/_team.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from agentdhal_core import Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.teams import BaseGroupChat
|
||||
|
||||
from ._task_runner_tool import TaskRunnerTool
|
||||
|
||||
|
||||
class TeamToolConfig(BaseModel):
|
||||
"""Configuration for the TeamTool."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
description: str
|
||||
"""The name and description of the tool."""
|
||||
team: ComponentModel
|
||||
"""The team to be used for running the task."""
|
||||
return_value_as_last_message: bool = False
|
||||
"""Whether to return the value as the last message of the task result."""
|
||||
|
||||
|
||||
class TeamTool(TaskRunnerTool, Component[TeamToolConfig]):
|
||||
"""Tool that can be used to run a task.
|
||||
|
||||
The tool returns the result of the task execution as a :class:`~agentdhal_agentchat.base.TaskResult` object.
|
||||
|
||||
.. important::
|
||||
When using TeamTool, you **must** disable parallel tool calls in the model client configuration
|
||||
to avoid concurrency issues. Teams cannot run concurrently as they maintain internal state
|
||||
that would conflict with parallel execution. For example, set ``parallel_tool_calls=False``
|
||||
for :class:`~agentdhal_extensions.models.openai.OpenAIChatCompletionClient` and
|
||||
:class:`~agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient`.
|
||||
|
||||
Args:
|
||||
team (BaseGroupChat): The team to be used for running the task.
|
||||
name (str): The name of the tool.
|
||||
description (str): The description of the tool.
|
||||
return_value_as_last_message (bool): Whether to use the last message content of the task result
|
||||
as the return value of the tool in :meth:`~agentdhal_agentchat.tools.TaskRunnerTool.return_value_as_string`.
|
||||
If set to True, the last message content will be returned as a string.
|
||||
If set to False, the tool will return all messages in the task result as a string concatenated together,
|
||||
with each message prefixed by its source (e.g., "writer: ...", "assistant: ...").
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import SourceMatchTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.tools import TeamTool
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Disable parallel tool calls when using TeamTool
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
|
||||
writer = AssistantAgent(name="writer", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
reviewer = AssistantAgent(
|
||||
name="reviewer", model_client=model_client, system_message="You are a critical reviewer."
|
||||
)
|
||||
summarizer = AssistantAgent(
|
||||
name="summarizer",
|
||||
model_client=model_client,
|
||||
system_message="You combine the review and produce a revised response.",
|
||||
)
|
||||
team = RoundRobinGroupChat(
|
||||
[writer, reviewer, summarizer], termination_condition=SourceMatchTermination(sources=["summarizer"])
|
||||
)
|
||||
|
||||
# Create a TeamTool that uses the team to run tasks, returning the last message as the result.
|
||||
tool = TeamTool(
|
||||
team=team,
|
||||
name="writing_team",
|
||||
description="A tool for writing tasks.",
|
||||
return_value_as_last_message=True,
|
||||
)
|
||||
|
||||
# Create model client with parallel tool calls disabled for the main agent
|
||||
main_model_client = OpenAIChatCompletionClient(model="gpt-4.1", parallel_tool_calls=False)
|
||||
main_agent = AssistantAgent(
|
||||
name="main_agent",
|
||||
model_client=main_model_client,
|
||||
system_message="You are a helpful assistant that can use the writing tool.",
|
||||
tools=[tool],
|
||||
)
|
||||
# For handling each events manually.
|
||||
# async for message in main_agent.run_stream(
|
||||
# task="Write a short story about a robot learning to love.",
|
||||
# ):
|
||||
# print(message)
|
||||
# Use Console to display the messages in a more readable format.
|
||||
await Console(
|
||||
main_agent.run_stream(
|
||||
task="Write a short story about a robot learning to love.",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = TeamToolConfig
|
||||
component_provider_override = "agentdhal_agentchat.tools.TeamTool"
|
||||
|
||||
def __init__(
|
||||
self, team: BaseGroupChat, name: str, description: str, return_value_as_last_message: bool = False
|
||||
) -> None:
|
||||
self._team = team
|
||||
super().__init__(team, name, description, return_value_as_last_message=return_value_as_last_message)
|
||||
|
||||
def _to_config(self) -> TeamToolConfig:
|
||||
return TeamToolConfig(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
team=self._team.dump_component(),
|
||||
return_value_as_last_message=self._return_value_as_last_message,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TeamToolConfig) -> Self:
|
||||
return cls(
|
||||
BaseGroupChat.load_component(config.team),
|
||||
config.name,
|
||||
config.description,
|
||||
config.return_value_as_last_message,
|
||||
)
|
||||
7
agent_dhal/agentdhal_agentchat/ui/__init__.py
Normal file
7
agent_dhal/agentdhal_agentchat/ui/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
This module implements utility classes for formatting/printing agent messages.
|
||||
"""
|
||||
|
||||
from ._console import Console, UserInputManager
|
||||
|
||||
__all__ = ["Console", "UserInputManager"]
|
||||
204
agent_dhal/agentdhal_agentchat/ui/_console.py
Normal file
204
agent_dhal/agentdhal_agentchat/ui/_console.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.models import RequestUsage
|
||||
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_agentchat.base import Response, TaskResult
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
UserInputRequestedEvent,
|
||||
)
|
||||
|
||||
|
||||
def _is_running_in_iterm() -> bool:
|
||||
return os.getenv("TERM_PROGRAM") == "iTerm.app"
|
||||
|
||||
|
||||
def _is_output_a_tty() -> bool:
|
||||
return sys.stdout.isatty()
|
||||
|
||||
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||
|
||||
T = TypeVar("T", bound=TaskResult | Response)
|
||||
|
||||
|
||||
class UserInputManager:
|
||||
def __init__(self, callback: InputFuncType):
|
||||
self.input_events: Dict[str, asyncio.Event] = {}
|
||||
self.callback = callback
|
||||
|
||||
def get_wrapped_callback(self) -> AsyncInputFunc:
|
||||
async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
# Lookup the event for the prompt, if it exists wait for it.
|
||||
# If it doesn't exist, create it and store it.
|
||||
# Get request ID:
|
||||
request_id = UserProxyAgent.InputRequestContext.request_id()
|
||||
if request_id in self.input_events:
|
||||
event = self.input_events[request_id]
|
||||
else:
|
||||
event = asyncio.Event()
|
||||
self.input_events[request_id] = event
|
||||
|
||||
await event.wait()
|
||||
|
||||
del self.input_events[request_id]
|
||||
|
||||
if iscoroutinefunction(self.callback):
|
||||
# Cast to AsyncInputFunc for proper typing
|
||||
async_func = cast(AsyncInputFunc, self.callback)
|
||||
return await async_func(prompt, cancellation_token)
|
||||
else:
|
||||
# Cast to SyncInputFunc for proper typing
|
||||
sync_func = cast(SyncInputFunc, self.callback)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, sync_func, prompt)
|
||||
|
||||
return user_input_func_wrapper
|
||||
|
||||
def notify_event_received(self, request_id: str) -> None:
|
||||
if request_id in self.input_events:
|
||||
self.input_events[request_id].set()
|
||||
else:
|
||||
event = asyncio.Event()
|
||||
self.input_events[request_id] = event
|
||||
|
||||
|
||||
def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]:
|
||||
return asyncio.to_thread(print, output, end=end, flush=flush)
|
||||
|
||||
|
||||
async def Console(
|
||||
stream: AsyncGenerator[BaseAgentEvent | BaseChatMessage | T, None],
|
||||
*,
|
||||
no_inline_images: bool = False,
|
||||
output_stats: bool = False,
|
||||
user_input_manager: UserInputManager | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Consumes the message stream from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream`
|
||||
or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream` and renders the messages to the console.
|
||||
Returns the last processed TaskResult or Response.
|
||||
|
||||
.. note::
|
||||
|
||||
`output_stats` is experimental and the stats may not be accurate.
|
||||
It will be improved in future releases.
|
||||
|
||||
Args:
|
||||
stream (AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None] | AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]): Message stream to render.
|
||||
This can be from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream` or :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`.
|
||||
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
|
||||
output_stats (bool, optional): (Experimental) If True, will output a summary of the messages and inline token usage info. Defaults to False.
|
||||
|
||||
Returns:
|
||||
last_processed: A :class:`~agentdhal_agentchat.base.TaskResult` if the stream is from :meth:`~agentdhal_agentchat.base.TaskRunner.run_stream`
|
||||
or a :class:`~agentdhal_agentchat.base.Response` if the stream is from :meth:`~agentdhal_agentchat.base.ChatAgent.on_messages_stream`.
|
||||
"""
|
||||
render_image_iterm = _is_running_in_iterm() and _is_output_a_tty() and not no_inline_images
|
||||
start_time = time.time()
|
||||
total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
last_processed: Optional[T] = None
|
||||
|
||||
streaming_chunks: List[str] = []
|
||||
|
||||
async for message in stream:
|
||||
if isinstance(message, TaskResult):
|
||||
duration = time.time() - start_time
|
||||
if output_stats:
|
||||
output = (
|
||||
f"{'-' * 10} Summary {'-' * 10}\n"
|
||||
f"Number of messages: {len(message.messages)}\n"
|
||||
f"Finish reason: {message.stop_reason}\n"
|
||||
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
|
||||
f"Total completion tokens: {total_usage.completion_tokens}\n"
|
||||
f"Duration: {duration:.2f} seconds\n"
|
||||
)
|
||||
await aprint(output, end="", flush=True)
|
||||
|
||||
# mypy ignore
|
||||
last_processed = message # type: ignore
|
||||
|
||||
elif isinstance(message, Response):
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Print final response.
|
||||
if isinstance(message.chat_message, MultiModalMessage):
|
||||
final_content = message.chat_message.to_text(iterm=render_image_iterm)
|
||||
else:
|
||||
final_content = message.chat_message.to_text()
|
||||
output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{final_content}\n"
|
||||
if message.chat_message.models_usage:
|
||||
if output_stats:
|
||||
output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
|
||||
total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens
|
||||
total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens
|
||||
await aprint(output, end="", flush=True)
|
||||
|
||||
# Print summary.
|
||||
if output_stats:
|
||||
if message.inner_messages is not None:
|
||||
num_inner_messages = len(message.inner_messages)
|
||||
else:
|
||||
num_inner_messages = 0
|
||||
output = (
|
||||
f"{'-' * 10} Summary {'-' * 10}\n"
|
||||
f"Number of inner messages: {num_inner_messages}\n"
|
||||
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
|
||||
f"Total completion tokens: {total_usage.completion_tokens}\n"
|
||||
f"Duration: {duration:.2f} seconds\n"
|
||||
)
|
||||
await aprint(output, end="", flush=True)
|
||||
|
||||
# mypy ignore
|
||||
last_processed = message # type: ignore
|
||||
# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
|
||||
elif isinstance(message, UserInputRequestedEvent):
|
||||
if user_input_manager is not None:
|
||||
user_input_manager.notify_event_received(message.request_id)
|
||||
else:
|
||||
# Cast required for mypy to be happy
|
||||
message = cast(BaseAgentEvent | BaseChatMessage, message) # type: ignore
|
||||
if not streaming_chunks:
|
||||
# Print message sender.
|
||||
await aprint(
|
||||
f"{'-' * 10} {message.__class__.__name__} ({message.source}) {'-' * 10}", end="\n", flush=True
|
||||
)
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
await aprint(message.to_text(), end="", flush=True)
|
||||
streaming_chunks.append(message.content)
|
||||
else:
|
||||
if streaming_chunks:
|
||||
streaming_chunks.clear()
|
||||
# Chunked messages are already printed, so we just print a newline.
|
||||
await aprint("", end="\n", flush=True)
|
||||
elif isinstance(message, MultiModalMessage):
|
||||
await aprint(message.to_text(iterm=render_image_iterm), end="\n", flush=True)
|
||||
else:
|
||||
await aprint(message.to_text(), end="\n", flush=True)
|
||||
if message.models_usage:
|
||||
if output_stats:
|
||||
await aprint(
|
||||
f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
|
||||
end="\n",
|
||||
flush=True,
|
||||
)
|
||||
total_usage.completion_tokens += message.models_usage.completion_tokens
|
||||
total_usage.prompt_tokens += message.models_usage.prompt_tokens
|
||||
|
||||
if last_processed is None:
|
||||
raise ValueError("No TaskResult or Response was processed.")
|
||||
|
||||
return last_processed
|
||||
7
agent_dhal/agentdhal_agentchat/utils/__init__.py
Normal file
7
agent_dhal/agentdhal_agentchat/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
This module implements various utilities common to AgentChat agents and teams.
|
||||
"""
|
||||
|
||||
from ._utils import content_to_str, remove_images
|
||||
|
||||
__all__ = ["content_to_str", "remove_images"]
|
||||
44
agent_dhal/agentdhal_agentchat/utils/_utils.py
Normal file
44
agent_dhal/agentdhal_agentchat/utils/_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import List, Union
|
||||
|
||||
from agentdhal_core import FunctionCall, Image
|
||||
from agentdhal_core.models import FunctionExecutionResult, LLMMessage, UserMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Type aliases for convenience
|
||||
_StructuredContent = BaseModel
|
||||
_UserContent = Union[str, List[Union[str, Image]]]
|
||||
_AssistantContent = Union[str, List[FunctionCall]]
|
||||
_FunctionExecutionContent = List[FunctionExecutionResult]
|
||||
_SystemContent = str
|
||||
|
||||
|
||||
def content_to_str(
|
||||
content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent | _StructuredContent,
|
||||
) -> str:
|
||||
"""Convert the content of an LLMMessage to a string."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, BaseModel):
|
||||
return content.model_dump_json()
|
||||
else:
|
||||
result: List[str] = []
|
||||
for c in content:
|
||||
if isinstance(c, str):
|
||||
result.append(c)
|
||||
elif isinstance(c, Image):
|
||||
result.append("<image>")
|
||||
else:
|
||||
result.append(str(c))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
"""Remove images from a list of LLMMessages"""
|
||||
str_messages: List[LLMMessage] = []
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage) and isinstance(message.content, list):
|
||||
str_messages.append(UserMessage(content=content_to_str(message.content), source=message.source))
|
||||
else:
|
||||
str_messages.append(message)
|
||||
return str_messages
|
||||
142
agent_dhal/agentdhal_core/__init__.py
Normal file
142
agent_dhal/agentdhal_core/__init__.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# AgentDhal Core Module - Self-contained version
|
||||
__version__ = "1.0.0"
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cache_store import CacheStore, InMemoryStore
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._closure_agent import ClosureAgent, ClosureContext
|
||||
from ._component_config import (
|
||||
Component,
|
||||
ComponentBase,
|
||||
ComponentFromConfig,
|
||||
ComponentLoader,
|
||||
ComponentModel,
|
||||
ComponentSchemaType,
|
||||
ComponentToConfig,
|
||||
ComponentType,
|
||||
is_component_class,
|
||||
is_component_instance,
|
||||
)
|
||||
from ._constants import (
|
||||
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._constants import (
|
||||
ROOT_LOGGER_NAME as ROOT_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._constants import (
|
||||
TRACE_LOGGER_NAME as TRACE_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
from ._intervention import (
|
||||
DefaultInterventionHandler,
|
||||
DropMessage,
|
||||
InterventionHandler,
|
||||
)
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._routed_agent import RoutedAgent, event, message_handler, rpc
|
||||
from ._serialization import (
|
||||
JSON_DATA_CONTENT_TYPE as JSON_DATA_CONTENT_TYPE_ALIAS,
|
||||
)
|
||||
from ._serialization import (
|
||||
PROTOBUF_DATA_CONTENT_TYPE as PROTOBUF_DATA_CONTENT_TYPE_ALIAS,
|
||||
)
|
||||
from ._serialization import (
|
||||
MessageSerializer,
|
||||
UnknownPayload,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._telemetry import (
|
||||
trace_create_agent_span,
|
||||
trace_invoke_agent_span,
|
||||
trace_tool_span,
|
||||
)
|
||||
from ._topic import TopicId
|
||||
from ._type_prefix_subscription import TypePrefixSubscription
|
||||
from ._type_subscription import TypeSubscription
|
||||
from ._types import FunctionCall
|
||||
|
||||
EVENT_LOGGER_NAME = EVENT_LOGGER_NAME_ALIAS
|
||||
"""The name of the logger used for structured events."""
|
||||
|
||||
ROOT_LOGGER_NAME = ROOT_LOGGER_NAME_ALIAS
|
||||
"""The name of the root logger."""
|
||||
|
||||
TRACE_LOGGER_NAME = TRACE_LOGGER_NAME_ALIAS
|
||||
"""Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
|
||||
JSON_DATA_CONTENT_TYPE = JSON_DATA_CONTENT_TYPE_ALIAS
|
||||
"""The content type for JSON data."""
|
||||
|
||||
PROTOBUF_DATA_CONTENT_TYPE = PROTOBUF_DATA_CONTENT_TYPE_ALIAS
|
||||
"""The content type for Protobuf data."""
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentId",
|
||||
"AgentProxy",
|
||||
"AgentMetadata",
|
||||
"AgentRuntime",
|
||||
"BaseAgent",
|
||||
"CacheStore",
|
||||
"InMemoryStore",
|
||||
"CancellationToken",
|
||||
"AgentInstantiationContext",
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
"MessageContext",
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
"MessageHandlerContext",
|
||||
"MessageSerializer",
|
||||
"try_get_known_serializers_for_type",
|
||||
"UnknownPayload",
|
||||
"Image",
|
||||
"RoutedAgent",
|
||||
"ClosureAgent",
|
||||
"ClosureContext",
|
||||
"message_handler",
|
||||
"event",
|
||||
"rpc",
|
||||
"FunctionCall",
|
||||
"TypeSubscription",
|
||||
"DefaultSubscription",
|
||||
"DefaultTopicId",
|
||||
"default_subscription",
|
||||
"type_subscription",
|
||||
"TypePrefixSubscription",
|
||||
"JSON_DATA_CONTENT_TYPE",
|
||||
"PROTOBUF_DATA_CONTENT_TYPE",
|
||||
"SingleThreadedAgentRuntime",
|
||||
"ROOT_LOGGER_NAME",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"TRACE_LOGGER_NAME",
|
||||
"Component",
|
||||
"ComponentBase",
|
||||
"ComponentFromConfig",
|
||||
"ComponentLoader",
|
||||
"ComponentModel",
|
||||
"ComponentSchemaType",
|
||||
"ComponentToConfig",
|
||||
"ComponentType",
|
||||
"is_component_class",
|
||||
"is_component_instance",
|
||||
"DropMessage",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
"trace_create_agent_span",
|
||||
"trace_invoke_agent_span",
|
||||
"trace_tool_span",
|
||||
]
|
||||
64
agent_dhal/agentdhal_core/_agent.py
Normal file
64
agent_dhal/agentdhal_core/_agent.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._message_context import MessageContext
|
||||
|
||||
# Forward declaration for type checking only
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
"""Metadata of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""ID of the agent."""
|
||||
...
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
|
||||
"""Function used to bind an Agent instance to an `AgentRuntime`.
|
||||
|
||||
Args:
|
||||
agent_id (AgentId): ID of the agent.
|
||||
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
|
||||
"""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
|
||||
|
||||
Args:
|
||||
message (Any): Received message. Type is one of the types in `subscriptions`.
|
||||
ctx (MessageContext): Context of the message.
|
||||
|
||||
Returns:
|
||||
Any: Response to the message. Can be None.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the message was cancelled.
|
||||
CantHandleException: If the agent cannot handle the message.
|
||||
"""
|
||||
...
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
...
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Called when the runtime is closed"""
|
||||
...
|
||||
68
agent_dhal/agentdhal_core/_agent_id.py
Normal file
68
agent_dhal/agentdhal_core/_agent_id.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._agent_type import AgentType
|
||||
|
||||
|
||||
def is_valid_agent_type(value: str) -> bool:
|
||||
return bool(re.match(r"^[\w\-\.]+\Z", value))
|
||||
|
||||
|
||||
class AgentId:
|
||||
"""
|
||||
Agent ID uniquely identifies an agent instance within an agent runtime - including distributed runtime. It is the 'address' of the agent instance for receiving messages.
|
||||
|
||||
See here for more information: :ref:`agentid_and_lifecycle`
|
||||
"""
|
||||
|
||||
def __init__(self, type: str | AgentType, key: str) -> None:
|
||||
if isinstance(type, AgentType):
|
||||
type = type.type
|
||||
|
||||
if not is_valid_agent_type(type):
|
||||
raise ValueError(rf"Invalid agent type: {type}. Allowed values MUST match the regex: `^[\w\-\.]+\Z`")
|
||||
|
||||
self._type = type
|
||||
self._key = key
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._type, self._key))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self._type}/{self._key}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'AgentId(type="{self._type}", key="{self._key}")'
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if not isinstance(value, AgentId):
|
||||
return False
|
||||
return self._type == value.type and self._key == value.key
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, agent_id: str) -> Self:
|
||||
"""Convert a string of the format ``type/key`` into an AgentId"""
|
||||
items = agent_id.split("/", maxsplit=1)
|
||||
if len(items) != 2:
|
||||
raise ValueError(f"Invalid agent id: {agent_id}")
|
||||
type, key = items[0], items[1]
|
||||
return cls(type, key)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""
|
||||
An identifier that associates an agent with a specific factory function.
|
||||
|
||||
Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
|
||||
"""
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""
|
||||
Agent instance identifier.
|
||||
|
||||
Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
|
||||
"""
|
||||
return self._key
|
||||
126
agent_dhal/agentdhal_core/_agent_instantiation.py
Normal file
126
agent_dhal/agentdhal_core/_agent_instantiation.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentInstantiationContext:
|
||||
"""A static class that provides context for agent instantiation.
|
||||
|
||||
This static class can be used to access the current runtime and agent ID
|
||||
during agent instantiation -- inside the factory function or the agent's
|
||||
class constructor.
|
||||
|
||||
Example:
|
||||
|
||||
Get the current runtime and agent ID inside the factory function and
|
||||
the agent's constructor:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
message_handler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class TestAgent(RoutedAgent):
|
||||
def __init__(self, description: str):
|
||||
super().__init__(description)
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from constructor: {agent_id}")
|
||||
|
||||
@message_handler
|
||||
async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
print(f"Received message: {message.content}")
|
||||
|
||||
|
||||
def test_agent_factory() -> TestAgent:
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from factory: {agent_id}")
|
||||
return TestAgent(description="Test agent")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a SingleThreadedAgentRuntime instance.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Start the runtime.
|
||||
runtime.start()
|
||||
|
||||
# Register the agent type with a factory function.
|
||||
await runtime.register_factory("test_agent", test_agent_factory)
|
||||
|
||||
# Send a message to the agent. The runtime will instantiate the agent and call the message handler.
|
||||
await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default"))
|
||||
|
||||
# Stop the runtime.
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
)
|
||||
|
||||
_AGENT_INSTANTIATION_CONTEXT_VAR: ClassVar[ContextVar[tuple[AgentRuntime, AgentId]]] = ContextVar(
|
||||
"_AGENT_INSTANTIATION_CONTEXT_VAR"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def current_runtime(cls) -> AgentRuntime:
|
||||
try:
|
||||
return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[0]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def current_agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[1]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def is_in_factory_call(cls) -> bool:
|
||||
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
|
||||
return False
|
||||
return True
|
||||
7
agent_dhal/agentdhal_core/_agent_metadata.py
Normal file
7
agent_dhal/agentdhal_core/_agent_metadata.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class AgentMetadata(TypedDict):
|
||||
type: str
|
||||
key: str
|
||||
description: str
|
||||
56
agent_dhal/agentdhal_core/_agent_proxy.py
Normal file
56
agent_dhal/agentdhal_core/_agent_proxy.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentProxy:
|
||||
"""A helper class that allows you to use an :class:`~agentdhal_core.AgentId` in place of its associated :class:`~agentdhal_core.Agent`"""
|
||||
|
||||
def __init__(self, agent: AgentId, runtime: AgentRuntime):
|
||||
self._agent = agent
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""Target agent for this proxy"""
|
||||
return self._agent
|
||||
|
||||
@property
|
||||
def metadata(self) -> Awaitable[AgentMetadata]:
|
||||
"""Metadata of the agent."""
|
||||
return self._runtime.agent_metadata(self._agent)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
sender=sender,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
return await self._runtime.agent_save_state(self._agent)
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
await self._runtime.agent_load_state(self._agent, state)
|
||||
295
agent_dhal/agentdhal_core/_agent_runtime.py
Normal file
295
agent_dhal/agentdhal_core/_agent_runtime.py
Normal file
@@ -0,0 +1,295 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._serialization import MessageSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""Send a message to an agent and get a response.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
recipient (AgentId): The agent to send the message to.
|
||||
sender (AgentId | None, optional): Agent which sent the message. Should **only** be None if this was sent from no agent, such as directly to the runtime externally. Defaults to None.
|
||||
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
|
||||
|
||||
Raises:
|
||||
CantHandleException: If the recipient cannot handle the message.
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
Other: Any other exception raised by the recipient.
|
||||
|
||||
Returns:
|
||||
Any: The response from the agent.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
"""Publish a message to all agents in the given namespace, or if no namespace is provided, the namespace of the sender.
|
||||
|
||||
No responses are expected from publishing.
|
||||
|
||||
Args:
|
||||
message (Any): The message to publish.
|
||||
topic_id (TopicId): The topic to publish the message to.
|
||||
sender (AgentId | None, optional): The agent which sent the message. Defaults to None.
|
||||
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress. Defaults to None.
|
||||
message_id (str | None, optional): The message id. If None, a new message id will be generated. Defaults to None. This message id must be unique. and is recommended to be a UUID.
|
||||
|
||||
Raises:
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
type: str | AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
expected_class: type[T] | None = None,
|
||||
) -> AgentType:
|
||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique. This API does not add any subscriptions.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a low level API and usually the agent class's `register` method should be used instead, as this also handles subscriptions automatically.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import AgentRuntime, MessageContext, RoutedAgent, event
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("My core agent")
|
||||
|
||||
@event
|
||||
async def handler(self, message: UserMessage, context: MessageContext) -> None:
|
||||
print("Event received: ", message.content)
|
||||
|
||||
|
||||
async def my_agent_factory():
|
||||
return MyAgent()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime: AgentRuntime = ... # type: ignore
|
||||
await runtime.register_factory("my_agent", lambda: MyAgent())
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agentdhal_core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed.
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_agent_instance(
|
||||
self,
|
||||
agent_instance: Agent,
|
||||
agent_id: AgentId,
|
||||
) -> AgentId:
|
||||
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("My core agent")
|
||||
|
||||
@event
|
||||
async def handler(self, message: UserMessage, context: MessageContext) -> None:
|
||||
print("Event received: ", message.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime: AgentRuntime = ... # type: ignore
|
||||
agent = MyAgent()
|
||||
await runtime.register_agent_instance(
|
||||
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
|
||||
)
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
agent_instance (Agent): A concrete instance of the agent.
|
||||
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||
|
||||
If the underlying agent is not accessible, this will raise an exception.
|
||||
|
||||
Args:
|
||||
id (AgentId): The agent id.
|
||||
type (Type[T], optional): The expected type of the agent. Defaults to Agent.
|
||||
|
||||
Returns:
|
||||
T: The concrete agent instance.
|
||||
|
||||
Raises:
|
||||
LookupError: If the agent is not found.
|
||||
NotAccessibleError: If the agent is not accessible, for example if it is located remotely.
|
||||
TypeError: If the agent is not of the expected type.
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get(self, id: AgentId, /, *, lazy: bool = ...) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
async def get(self, type: AgentType | str, /, key: str = ..., *, lazy: bool = ...) -> AgentId: ...
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
) -> AgentId: ...
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
"""Get the metadata for an agent.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
|
||||
Returns:
|
||||
AgentMetadata: The agent metadata.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
"""Save the state of a single agent.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of a single agent.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
state (Mapping[str, Any]): The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
"""Add a new subscription that the runtime should fulfill when processing published messages
|
||||
|
||||
Args:
|
||||
subscription (Subscription): The subscription to add
|
||||
"""
|
||||
...
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
"""Remove a subscription from the runtime
|
||||
|
||||
Args:
|
||||
id (str): id of the subscription to remove
|
||||
|
||||
Raises:
|
||||
LookupError: If the subscription does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
"""Add a new message serialization serializer to the runtime
|
||||
|
||||
Note: This will deduplicate serializers based on the type_name and data_content_type properties
|
||||
|
||||
Args:
|
||||
serializer (MessageSerializer[Any] | Sequence[MessageSerializer[Any]]): The serializer/s to add
|
||||
"""
|
||||
...
|
||||
7
agent_dhal/agentdhal_core/_agent_type.py
Normal file
7
agent_dhal/agentdhal_core/_agent_type.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class AgentType:
|
||||
type: str
|
||||
"""String representation of this agent type."""
|
||||
254
agent_dhal/agentdhal_core/_base_agent.py
Normal file
254
agent_dhal/agentdhal_core/_base_agent.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._subscription import Subscription, UnboundSubscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
from ._type_prefix_subscription import TypePrefixSubscription
|
||||
from ._type_subscription import TypeSubscription
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
# Decorator for adding an unbound subscription to an agent
|
||||
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
""":meta private:"""
|
||||
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
cls.internal_unbound_subscriptions_list.append(subscription)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handles(
|
||||
type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
if serializer is None:
|
||||
serializer_list = try_get_known_serializers_for_type(type)
|
||||
else:
|
||||
serializer_list = [serializer] if not isinstance(serializer, Sequence) else serializer
|
||||
|
||||
if len(serializer_list) == 0:
|
||||
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
|
||||
|
||||
cls.internal_extra_handles_types.append((type, serializer_list))
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
|
||||
""":meta private:"""
|
||||
internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
|
||||
""":meta private:"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Automatically set class_variable in each subclass so that they are not shared between subclasses
|
||||
cls.internal_extra_handles_types = []
|
||||
cls.internal_unbound_subscriptions_list = []
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
return cls.internal_extra_handles_types
|
||||
|
||||
@classmethod
|
||||
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
|
||||
return cls.internal_unbound_subscriptions_list
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
if AgentInstantiationContext.is_in_factory_call():
|
||||
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
|
||||
self._id = AgentInstantiationContext.current_agent_id()
|
||||
if not isinstance(description, str):
|
||||
raise ValueError("Agent description must be a string")
|
||||
self._description = description
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
|
||||
if hasattr(self, "_id"):
|
||||
if self._id != id:
|
||||
raise RuntimeError("Agent is already bound to a different ID")
|
||||
|
||||
if hasattr(self, "_runtime"):
|
||||
if self._runtime != runtime:
|
||||
raise RuntimeError("Agent is already bound to a different runtime")
|
||||
|
||||
self._id = id
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.id.type
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def runtime(self) -> AgentRuntime:
|
||||
return self._runtime
|
||||
|
||||
@final
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
return await self.on_message_impl(message, ctx)
|
||||
|
||||
@abstractmethod
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: ...
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""See :py:meth:`agentdhal_core.AgentRuntime.send_message` for more information."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
return {}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_instance(
|
||||
self,
|
||||
runtime: AgentRuntime,
|
||||
agent_id: AgentId,
|
||||
*,
|
||||
skip_class_subscriptions: bool = True,
|
||||
skip_direct_message_subscription: bool = False,
|
||||
) -> AgentId:
|
||||
"""
|
||||
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
|
||||
"""
|
||||
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)
|
||||
|
||||
id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
|
||||
await runtime.add_subscription(id_subscription)
|
||||
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in self._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
if not skip_direct_message_subscription:
|
||||
# Additionally adds a special prefix subscription for this agent to receive direct messages
|
||||
try:
|
||||
await runtime.add_subscription(
|
||||
TypePrefixSubscription(
|
||||
# The prefix MUST include ":" to avoid collisions with other agents
|
||||
topic_type_prefix=agent_id.type + ":",
|
||||
agent_type=agent_id.type,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
# We don't care if the subscription already exists
|
||||
pass
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in self._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_id
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls,
|
||||
runtime: AgentRuntime,
|
||||
type: str,
|
||||
factory: Callable[[], Self | Awaitable[Self]],
|
||||
*,
|
||||
skip_class_subscriptions: bool = False,
|
||||
skip_direct_message_subscription: bool = False,
|
||||
) -> AgentType:
|
||||
agent_type = AgentType(type)
|
||||
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in cls._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
if not skip_direct_message_subscription:
|
||||
# Additionally adds a special prefix subscription for this agent to receive direct messages
|
||||
await runtime.add_subscription(
|
||||
TypePrefixSubscription(
|
||||
# The prefix MUST include ":" to avoid collisions with other agents
|
||||
topic_type_prefix=agent_type.type + ":",
|
||||
agent_type=agent_type.type,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in cls._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
70
agent_dhal/agentdhal_core/_cache_store.py
Normal file
70
agent_dhal/agentdhal_core/_cache_store.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._component_config import Component, ComponentBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]):
|
||||
"""
|
||||
This protocol defines the basic interface for store/cache operations.
|
||||
|
||||
Sub-classes should handle the lifecycle of underlying storage.
|
||||
"""
|
||||
|
||||
component_type = "cache_store"
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
|
||||
"""
|
||||
Retrieve an item from the store.
|
||||
|
||||
Args:
|
||||
key: The key identifying the item in the store.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The value associated with the key if found, else the default value.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: T) -> None:
|
||||
"""
|
||||
Set an item in the store.
|
||||
|
||||
Args:
|
||||
key: The key under which the item is to be stored.
|
||||
value: The value to be stored in the store.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryStoreConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]):
|
||||
component_provider_override = "agentdhal_core.InMemoryStore"
|
||||
component_config_schema = InMemoryStoreConfig
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: Dict[str, T] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def _to_config(self) -> InMemoryStoreConfig:
|
||||
return InMemoryStoreConfig()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: InMemoryStoreConfig) -> Self:
|
||||
return cls()
|
||||
46
agent_dhal/agentdhal_core/_cancellation_token.py
Normal file
46
agent_dhal/agentdhal_core/_cancellation_token.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import threading
|
||||
from asyncio import Future
|
||||
from typing import Any, Callable, List
|
||||
|
||||
|
||||
class CancellationToken:
|
||||
"""A token used to cancel pending async calls"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._cancelled: bool = False
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
self._callbacks: List[Callable[[], None]] = []
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel pending async calls linked to this cancellation token."""
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._cancelled = True
|
||||
for callback in self._callbacks:
|
||||
callback()
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Check if the CancellationToken has been used"""
|
||||
with self._lock:
|
||||
return self._cancelled
|
||||
|
||||
def add_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Attach a callback that will be called when cancel is invoked"""
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
callback()
|
||||
else:
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def link_future(self, future: Future[Any]) -> Future[Any]:
|
||||
"""Link a pending async call to a token to allow its cancellation"""
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
future.cancel()
|
||||
else:
|
||||
|
||||
def _cancel() -> None:
|
||||
future.cancel()
|
||||
|
||||
self._callbacks.append(_cancel)
|
||||
return future
|
||||
241
agent_dhal/agentdhal_core/_closure_agent.py
Normal file
241
agent_dhal/agentdhal_core/_closure_agent.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import try_get_known_serializers_for_type
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
from ._type_helpers import get_types
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
T = TypeVar("T")
|
||||
ClosureAgentType = TypeVar("ClosureAgentType", bound="ClosureAgent")
|
||||
|
||||
|
||||
def get_handled_types_from_closure(
|
||||
closure: Callable[[ClosureAgent, T, MessageContext], Awaitable[Any]],
|
||||
) -> Sequence[type]:
|
||||
args = inspect.getfullargspec(closure)[0]
|
||||
if len(args) != 3:
|
||||
raise AssertionError("Closure must have 4 arguments")
|
||||
|
||||
message_arg_name = args[1]
|
||||
|
||||
type_hints = get_type_hints(closure)
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints[message_arg_name])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
return target_types
|
||||
|
||||
|
||||
class ClosureContext(Protocol):
|
||||
@property
|
||||
def id(self) -> AgentId: ...
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ClosureAgent(BaseAgent, ClosureContext):
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||
*,
|
||||
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||
) -> None:
|
||||
try:
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
) from e
|
||||
|
||||
self._runtime: AgentRuntime = runtime
|
||||
self._id: AgentId = id
|
||||
self._description = description
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
self._expected_types = handled_types
|
||||
self._closure = closure
|
||||
self._unknown_type_policy = unknown_type_policy
|
||||
super().__init__(description)
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(
|
||||
key=self._id.key,
|
||||
type=self._id.type,
|
||||
description=self._description,
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def runtime(self) -> AgentRuntime:
|
||||
return self._runtime
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
|
||||
if type(message) not in self._expected_types:
|
||||
if self._unknown_type_policy == "warn":
|
||||
warnings.warn(
|
||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
|
||||
stacklevel=1,
|
||||
)
|
||||
return None
|
||||
elif self._unknown_type_policy == "error":
|
||||
raise CantHandleException(
|
||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
|
||||
)
|
||||
|
||||
return await self._closure(self, message, ctx)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Closure agents do not have state. So this method always returns an empty dictionary."""
|
||||
return {}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Closure agents do not have state. So this method does nothing."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def register_closure(
|
||||
cls,
|
||||
runtime: AgentRuntime,
|
||||
type: str,
|
||||
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||
*,
|
||||
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||
skip_direct_message_subscription: bool = False,
|
||||
description: str = "",
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
|
||||
) -> AgentType:
|
||||
"""The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime.
|
||||
|
||||
The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core._default_subscription import DefaultSubscription
|
||||
from agentdhal_core._default_topic import DefaultTopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
async def main():
|
||||
queue = asyncio.Queue[MyMessage]()
|
||||
|
||||
async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
|
||||
await queue.put(message)
|
||||
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await ClosureAgent.register_closure(
|
||||
runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
|
||||
)
|
||||
|
||||
runtime.start()
|
||||
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
result = await queue.get()
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
runtime (AgentRuntime): Runtime to register the agent to
|
||||
type (str): Agent type of registered agent
|
||||
closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages
|
||||
unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
|
||||
skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False.
|
||||
description (str, optional): Description of what agent does. Defaults to "".
|
||||
subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None.
|
||||
|
||||
Returns:
|
||||
AgentType: Type of the agent that was registered
|
||||
"""
|
||||
|
||||
def factory() -> ClosureAgent:
|
||||
return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)
|
||||
|
||||
assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions"
|
||||
agent_type = await cls.register(
|
||||
runtime=runtime,
|
||||
type=type,
|
||||
factory=factory, # type: ignore
|
||||
# There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s
|
||||
skip_class_subscriptions=True,
|
||||
skip_direct_message_subscription=skip_direct_message_subscription,
|
||||
)
|
||||
|
||||
subscriptions_list: List[Subscription] = []
|
||||
if subscriptions is not None:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list.extend(await subscriptions_list_result)
|
||||
else:
|
||||
# just ignore mypy here
|
||||
subscriptions_list.extend(subscriptions_list_result) # type: ignore
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
for message_type in handled_types:
|
||||
# TODO: support custom serializers
|
||||
serializer = try_get_known_serializers_for_type(message_type)
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
361
agent_dhal/agentdhal_core/_component_config.py
Normal file
361
agent_dhal/agentdhal_core/_component_config.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str
|
||||
ConfigT = TypeVar("ConfigT", bound=BaseModel)
|
||||
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
|
||||
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, covariant=True)
|
||||
|
||||
|
||||
class ComponentModel(BaseModel):
|
||||
"""Model class for a component. Contains all information required to instantiate a component."""
|
||||
|
||||
provider: str
|
||||
"""Describes how the component can be instantiated."""
|
||||
|
||||
component_type: ComponentType | None = None
|
||||
"""Logical type of the component. If missing, the component assumes the default type of the provider."""
|
||||
|
||||
version: int | None = None
|
||||
"""Version of the component specification. If missing, the component assumes whatever is the current version of the library used to load it. This is obviously dangerous and should be used for user authored ephmeral config. For all other configs version should be specified."""
|
||||
|
||||
component_version: int | None = None
|
||||
"""Version of the component. If missing, the component assumes the default version of the provider."""
|
||||
|
||||
description: str | None = None
|
||||
"""Description of the component."""
|
||||
|
||||
label: str | None = None
|
||||
"""Human readable label for the component. If missing the component assumes the class name of the provider."""
|
||||
|
||||
config: dict[str, Any]
|
||||
"""The schema validated config field is passed to a given class's implmentation of :py:meth:`agentdhal_core.ComponentConfigImpl._from_config` to create a new instance of the component class."""
|
||||
|
||||
|
||||
def _type_to_provider_str(t: type) -> str:
|
||||
return f"{t.__module__}.{t.__qualname__}"
|
||||
|
||||
|
||||
WELL_KNOWN_PROVIDERS = {
|
||||
"azure_openai_chat_completion_client": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient",
|
||||
"AzureOpenAIChatCompletionClient": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient",
|
||||
"openai_chat_completion_client": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient",
|
||||
"OpenAIChatCompletionClient": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient",
|
||||
}
|
||||
|
||||
|
||||
class ComponentFromConfig(Generic[FromConfigT]):
|
||||
@classmethod
|
||||
def _from_config(cls, config: FromConfigT) -> Self:
|
||||
"""Create a new instance of the component from a configuration object.
|
||||
|
||||
Args:
|
||||
config (T): The configuration object.
|
||||
|
||||
Returns:
|
||||
Self: The new instance of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
@classmethod
|
||||
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
|
||||
"""Create a new instance of the component from a previous version of the configuration object.
|
||||
|
||||
This is only called when the version of the configuration object is less than the current version, since in this case the schema is not known.
|
||||
|
||||
Args:
|
||||
config (Dict[str, Any]): The configuration object.
|
||||
version (int): The version of the configuration object.
|
||||
|
||||
Returns:
|
||||
Self: The new instance of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support loading from past versions")
|
||||
|
||||
|
||||
class ComponentToConfig(Generic[ToConfigT]):
|
||||
"""The two methods a class must implement to be a component.
|
||||
|
||||
Args:
|
||||
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
|
||||
"""
|
||||
|
||||
component_type: ClassVar[ComponentType]
|
||||
"""The logical type of the component."""
|
||||
component_version: ClassVar[int] = 1
|
||||
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
|
||||
component_provider_override: ClassVar[str | None] = None
|
||||
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
|
||||
component_description: ClassVar[str | None] = None
|
||||
"""A description of the component. If not provided, the docstring of the class will be used."""
|
||||
component_label: ClassVar[str | None] = None
|
||||
"""A human readable label for the component. If not provided, the component class name will be used."""
|
||||
|
||||
def _to_config(self) -> ToConfigT:
|
||||
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
|
||||
|
||||
Returns:
|
||||
T: The configuration of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
def dump_component(self) -> ComponentModel:
|
||||
"""Dump the component to a model that can be loaded back in.
|
||||
|
||||
Raises:
|
||||
TypeError: If the component is a local class.
|
||||
|
||||
Returns:
|
||||
ComponentModel: The model representing the component.
|
||||
"""
|
||||
if self.component_provider_override is not None:
|
||||
provider = self.component_provider_override
|
||||
else:
|
||||
provider = _type_to_provider_str(self.__class__)
|
||||
# Warn if internal module name is used,
|
||||
if "._" in provider:
|
||||
warnings.warn(
|
||||
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if "<locals>" in provider:
|
||||
raise TypeError("Cannot dump component with local class")
|
||||
|
||||
if not hasattr(self, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
description = self.component_description
|
||||
if description is None and self.__class__.__doc__:
|
||||
# use docstring as description
|
||||
docstring = self.__class__.__doc__.strip()
|
||||
for marker in ["\n\nArgs:", "\n\nParameters:", "\n\nAttributes:", "\n\n"]:
|
||||
docstring = docstring.split(marker)[0]
|
||||
description = docstring.strip()
|
||||
|
||||
obj_config = self._to_config().model_dump(exclude_none=True)
|
||||
model = ComponentModel(
|
||||
provider=provider,
|
||||
component_type=self.component_type,
|
||||
version=self.component_version,
|
||||
component_version=self.component_version,
|
||||
description=description,
|
||||
label=self.component_label or self.__class__.__name__,
|
||||
config=obj_config,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
ExpectedType = TypeVar("ExpectedType")
|
||||
|
||||
|
||||
class ComponentLoader:
|
||||
@overload
|
||||
@classmethod
|
||||
def load_component(cls, model: ComponentModel | Dict[str, Any], expected: None = None) -> Self: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def load_component(cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType]) -> ExpectedType: ...
|
||||
|
||||
@classmethod
|
||||
def load_component(
|
||||
cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType] | None = None
|
||||
) -> Self | ExpectedType:
|
||||
"""Load a component from a model. Intended to be used with the return type of :py:meth:`agentdhal_core.ComponentConfig.dump_component`.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import ComponentModel
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
|
||||
component: ComponentModel = ... # type: ignore
|
||||
|
||||
model_client = ChatCompletionClient.load_component(component)
|
||||
|
||||
Args:
|
||||
model (ComponentModel): The model to load the component from.
|
||||
|
||||
Returns:
|
||||
Self: The loaded component.
|
||||
|
||||
Args:
|
||||
model (ComponentModel): _description_
|
||||
expected (Type[ExpectedType] | None, optional): Explicit type only if used directly on ComponentLoader. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider string is invalid.
|
||||
TypeError: Provider is not a subclass of ComponentConfigImpl, or the expected type does not match.
|
||||
|
||||
Returns:
|
||||
Self | ExpectedType: The loaded component.
|
||||
"""
|
||||
|
||||
# Use global and add further type checks
|
||||
|
||||
if isinstance(model, dict):
|
||||
loaded_model = ComponentModel(**model)
|
||||
else:
|
||||
loaded_model = model
|
||||
|
||||
# First, do a look up in well known providers
|
||||
if loaded_model.provider in WELL_KNOWN_PROVIDERS:
|
||||
loaded_model.provider = WELL_KNOWN_PROVIDERS[loaded_model.provider]
|
||||
|
||||
output = loaded_model.provider.rsplit(".", maxsplit=1)
|
||||
if len(output) != 2:
|
||||
raise ValueError("Invalid")
|
||||
|
||||
module_path, class_name = output
|
||||
module = importlib.import_module(module_path)
|
||||
component_class = module.__getattribute__(class_name)
|
||||
|
||||
if not is_component_class(component_class):
|
||||
raise TypeError("Invalid component class")
|
||||
|
||||
# We need to check the schema is valid
|
||||
if not hasattr(component_class, "component_config_schema"):
|
||||
raise AttributeError("component_config_schema not defined")
|
||||
|
||||
if not hasattr(component_class, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
loaded_config_version = loaded_model.component_version or component_class.component_version
|
||||
if loaded_config_version < component_class.component_version:
|
||||
try:
|
||||
instance = component_class._from_config_past_version(loaded_model.config, loaded_config_version) # type: ignore
|
||||
except NotImplementedError as e:
|
||||
raise NotImplementedError(
|
||||
f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented"
|
||||
) from e
|
||||
else:
|
||||
schema = component_class.component_config_schema # type: ignore
|
||||
validated_config = schema.model_validate(loaded_model.config)
|
||||
|
||||
# We're allowed to use the private method here
|
||||
instance = component_class._from_config(validated_config) # type: ignore
|
||||
|
||||
if expected is None and not isinstance(instance, cls):
|
||||
raise TypeError("Expected type does not match")
|
||||
elif expected is None:
|
||||
return cast(Self, instance)
|
||||
elif not isinstance(instance, expected):
|
||||
raise TypeError("Expected type does not match")
|
||||
else:
|
||||
return cast(ExpectedType, instance)
|
||||
|
||||
|
||||
class ComponentSchemaType(Generic[ConfigT]):
|
||||
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
|
||||
component_config_schema: Type[ConfigT]
|
||||
"""The Pydantic model class which represents the configuration of the component."""
|
||||
|
||||
required_class_vars = ["component_config_schema", "component_type"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
|
||||
# TODO: validate provider is loadable
|
||||
for var in cls.required_class_vars:
|
||||
if not hasattr(cls, var):
|
||||
warnings.warn(
|
||||
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...
|
||||
|
||||
|
||||
class Component(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
Generic[ConfigT],
|
||||
):
|
||||
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
|
||||
|
||||
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
|
||||
- :py:attr:`component_type` - What is the logical type of the component.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from agentdhal_core import Component
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class MyComponent(Component[Config]):
|
||||
component_type = "custom"
|
||||
component_config_schema = Config
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def _to_config(self) -> Config:
|
||||
return Config(value=self.value)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: Config) -> MyComponent:
|
||||
return cls(value=config.value)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if not is_component_class(cls):
|
||||
warnings.warn(
|
||||
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
# Should never be used directly, only for type checking
|
||||
class _ConcreteComponent(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
ComponentToConfig[ConfigT],
|
||||
ComponentLoader,
|
||||
Generic[ConfigT],
|
||||
): ...
|
||||
|
||||
|
||||
def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
|
||||
return (
|
||||
isinstance(cls, ComponentFromConfig)
|
||||
and isinstance(cls, ComponentToConfig)
|
||||
and isinstance(cls, ComponentSchemaType)
|
||||
and isinstance(cls, ComponentLoader)
|
||||
)
|
||||
|
||||
|
||||
def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
|
||||
return (
|
||||
issubclass(cls, ComponentFromConfig)
|
||||
and issubclass(cls, ComponentToConfig)
|
||||
and issubclass(cls, ComponentSchemaType)
|
||||
and issubclass(cls, ComponentLoader)
|
||||
)
|
||||
9
agent_dhal/agentdhal_core/_constants.py
Normal file
9
agent_dhal/agentdhal_core/_constants.py
Normal file
@@ -0,0 +1,9 @@
|
||||
ROOT_LOGGER_NAME = "agentdhal_core"
|
||||
"""str: Logger name used for root logger"""
|
||||
|
||||
EVENT_LOGGER_NAME = "agentdhal_core.events"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
|
||||
|
||||
TRACE_LOGGER_NAME = "agentdhal_core.trace"
|
||||
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
53
agent_dhal/agentdhal_core/_default_subscription.py
Normal file
53
agent_dhal/agentdhal_core/_default_subscription.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Callable, Type, TypeVar, overload
|
||||
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent, subscription_factory
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._type_subscription import TypeSubscription
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class DefaultSubscription(TypeSubscription):
|
||||
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents.
|
||||
|
||||
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context.
|
||||
|
||||
Args:
|
||||
topic_type (str, optional): The topic type to subscribe to. Defaults to "default".
|
||||
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type: str = "default", agent_type: str | AgentType | None = None):
|
||||
if agent_type is None:
|
||||
try:
|
||||
agent_type = SubscriptionInstantiationContext.agent_type().type
|
||||
except RuntimeError as e:
|
||||
raise CantHandleException(
|
||||
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register"
|
||||
) from e
|
||||
|
||||
super().__init__(topic_type, agent_type)
|
||||
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription() -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: ...
|
||||
|
||||
|
||||
def default_subscription(
|
||||
cls: Type[BaseAgentType] | None = None,
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]] | Type[BaseAgentType]:
|
||||
if cls is None:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])
|
||||
else:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])(cls)
|
||||
|
||||
|
||||
def type_subscription(topic_type: str) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
return subscription_factory(lambda: [DefaultSubscription(topic_type=topic_type)])
|
||||
23
agent_dhal/agentdhal_core/_default_topic.py
Normal file
23
agent_dhal/agentdhal_core/_default_topic.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class DefaultTopicId(TopicId):
|
||||
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId.
|
||||
|
||||
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default".
|
||||
|
||||
Args:
|
||||
type (str, optional): Topic type to publish message to. Defaults to "default".
|
||||
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, type: str = "default", source: str | None = None) -> None:
|
||||
if source is None:
|
||||
try:
|
||||
source = MessageHandlerContext.agent_id().key
|
||||
# If we aren't in the context of a message handler, we use the default source
|
||||
except RuntimeError:
|
||||
source = "default"
|
||||
|
||||
super().__init__(type, source)
|
||||
324
agent_dhal/agentdhal_core/_function_utils.py
Normal file
324
agent_dhal/agentdhal_core/_function_utils.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
|
||||
# Credit to original authors
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, create_model # type: ignore
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import Literal
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get the signature of a function with type annotations.
|
||||
|
||||
Args:
|
||||
call: The function to get the signature for
|
||||
|
||||
Returns:
|
||||
The signature of the function with type annotations
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
func_call = call.func if isinstance(call, partial) else call
|
||||
type_hints = typing.get_type_hints(func_call, globalns, include_extras=True)
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=type_hints[param.name],
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return_annotation = type_hints.get("return", inspect.Signature.empty)
|
||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
"""Get the return annotation of a function.
|
||||
|
||||
Args:
|
||||
call: The function to get the return annotation for
|
||||
|
||||
Returns:
|
||||
The return annotation of the function
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
if annotation is inspect.Signature.empty:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||
return type_hints.get("return", inspect.Signature.empty)
|
||||
|
||||
|
||||
def get_param_annotations(
|
||||
typed_signature: inspect.Signature,
|
||||
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
|
||||
"""Get the type annotations of the parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
|
||||
Returns:
|
||||
A dictionary of the type annotations of the parameters of the function
|
||||
"""
|
||||
return {
|
||||
k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
|
||||
}
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
"""Parameters of a function as defined by the OpenAI API"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
properties: Dict[str, Dict[str, Any]]
|
||||
required: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""A function as defined by the OpenAI API"""
|
||||
|
||||
description: Annotated[str, Field(description="Description of the function")]
|
||||
name: Annotated[str, Field(description="Name of the function")]
|
||||
parameters: Annotated[Parameters, Field(description="Parameters of the function")]
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
|
||||
"""A function under tool as defined by the OpenAI API."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
function: Annotated[Function, Field(description="Function under tool")]
|
||||
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
|
||||
# handles Annotated
|
||||
if hasattr(v, "__metadata__"):
|
||||
retval = v.__metadata__[0]
|
||||
if isinstance(retval, str):
|
||||
return retval
|
||||
else:
|
||||
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
|
||||
else:
|
||||
return k
|
||||
|
||||
|
||||
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
k: The name of the parameter
|
||||
v: The type of the parameter
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydanitc model for the parameter
|
||||
"""
|
||||
|
||||
schema = TypeAdapter(v).json_schema()
|
||||
if k in default_values:
|
||||
dv = default_values[k]
|
||||
schema["default"] = dv
|
||||
|
||||
schema["description"] = type2description(k, v)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
|
||||
"""Get the required parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A list of the required parameters of the function
|
||||
"""
|
||||
return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]
|
||||
|
||||
|
||||
def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
|
||||
"""Get default values of parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A dictionary of the default values of the parameters of the function
|
||||
"""
|
||||
return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}
|
||||
|
||||
|
||||
def get_parameters(
|
||||
required: List[str],
|
||||
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
|
||||
default_values: Dict[str, Any],
|
||||
) -> Parameters:
|
||||
"""Get the parameters of a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
required: The required parameters of the function
|
||||
param_annotations: A dictionary of the type annotations of the parameters of the function
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydantic model for the parameters of the function
|
||||
"""
|
||||
return Parameters(
|
||||
properties={
|
||||
k: get_parameter_json_schema(k, v, default_values)
|
||||
for k, v in param_annotations.items()
|
||||
if v is not inspect.Signature.empty
|
||||
},
|
||||
required=required,
|
||||
)
|
||||
|
||||
|
||||
def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
|
||||
"""Get the missing annotations of a function
|
||||
|
||||
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
required: The required parameters of the function
|
||||
|
||||
Returns:
|
||||
A set of the missing annotations of the function
|
||||
"""
|
||||
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
|
||||
missing = all_missing.intersection(set(required))
|
||||
unannotated_with_default = all_missing.difference(missing)
|
||||
return missing, unannotated_with_default
|
||||
|
||||
|
||||
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
f: The function to get the JSON schema for
|
||||
name: The name of the function
|
||||
description: The description of the function
|
||||
|
||||
Returns:
|
||||
A JSON schema for the function
|
||||
|
||||
Raises:
|
||||
TypeError: If the function is not annotated
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def f(
|
||||
a: Annotated[str, "Parameter a"],
|
||||
b: int = 2,
|
||||
c: Annotated[float, "Parameter c"] = 0.1,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
get_function_schema(f, description="function f")
|
||||
|
||||
# {'type': 'function',
|
||||
# 'function': {'description': 'function f',
|
||||
# 'name': 'f',
|
||||
# 'parameters': {'type': 'object',
|
||||
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
|
||||
# 'b': {'type': 'int', 'description': 'b'},
|
||||
# 'c': {'type': 'float', 'description': 'Parameter c'}},
|
||||
# 'required': ['a']}}}
|
||||
|
||||
"""
|
||||
typed_signature = get_typed_signature(f)
|
||||
required = get_required_params(typed_signature)
|
||||
default_values = get_default_values(typed_signature)
|
||||
param_annotations = get_param_annotations(typed_signature)
|
||||
return_annotation = get_typed_return_annotation(f)
|
||||
missing, unannotated_with_default = get_missing_annotations(typed_signature, required)
|
||||
|
||||
if return_annotation is None:
|
||||
logger.warning(
|
||||
f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
|
||||
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
|
||||
)
|
||||
|
||||
if unannotated_with_default != set():
|
||||
unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
|
||||
logger.warning(
|
||||
f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
|
||||
+ f"{', '.join(unannotated_with_default_s)}."
|
||||
)
|
||||
|
||||
if missing != set():
|
||||
missing_s = [f"'{k}'" for k in sorted(missing)]
|
||||
raise TypeError(
|
||||
f"All parameters of the function '{f.__name__}' without default values must be annotated. "
|
||||
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
|
||||
)
|
||||
|
||||
fname = name if name else f.__name__
|
||||
|
||||
parameters = get_parameters(required, param_annotations, default_values=default_values)
|
||||
|
||||
function = ToolFunction(
|
||||
function=Function(
|
||||
description=description,
|
||||
name=fname,
|
||||
parameters=parameters,
|
||||
)
|
||||
)
|
||||
|
||||
return function.model_dump()
|
||||
|
||||
|
||||
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
|
||||
"""Normalize typing.Annotated types to the inner type."""
|
||||
if get_origin(type_hint) is Annotated:
|
||||
# Extract the inner type from Annotated
|
||||
return get_args(type_hint)[0] # type: ignore
|
||||
return type_hint
|
||||
|
||||
|
||||
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
|
||||
fields: Dict[str, tuple[Type[Any], Any]] = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
# This is handled externally
|
||||
if param_name == "cancellation_token":
|
||||
continue
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
raise ValueError("No annotation")
|
||||
|
||||
type = normalize_annotated_type(param.annotation)
|
||||
description = type2description(param_name, param.annotation)
|
||||
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
|
||||
|
||||
fields[param_name] = (type, Field(default=default_value, description=description))
|
||||
|
||||
return cast(BaseModel, create_model(name, **fields)) # type: ignore
|
||||
127
agent_dhal/agentdhal_core/_image.py
Normal file
127
agent_dhal/agentdhal_core/_image.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import GetCoreSchemaHandler, ValidationInfo
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class Image:
|
||||
"""Represents an image.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
Loading an image from a URL:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import Image
|
||||
from PIL import Image as PILImage
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
|
||||
async def from_url(url: str) -> Image:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
content = await response.read()
|
||||
return Image.from_pil(PILImage.open(content))
|
||||
|
||||
|
||||
image = asyncio.run(from_url("https://example.com/image"))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, image: PILImage.Image):
|
||||
self.image: PILImage.Image = image.convert("RGB")
|
||||
|
||||
@classmethod
|
||||
def from_pil(cls, pil_image: PILImage.Image) -> Image:
|
||||
return cls(pil_image)
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> Image:
|
||||
if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
|
||||
raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")
|
||||
|
||||
# A URI. Remove the prefix and decode the base64 string.
|
||||
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
|
||||
return cls.from_base64(base64_data)
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, base64_str: str) -> Image:
|
||||
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
|
||||
|
||||
def to_base64(self) -> str:
|
||||
buffered = BytesIO()
|
||||
self.image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file_path: Path) -> Image:
|
||||
return cls(PILImage.open(file_path))
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
# Show the image in Jupyter notebook
|
||||
return f'<img src="{self.data_uri}"/>'
|
||||
|
||||
@property
|
||||
def data_uri(self) -> str:
|
||||
return _convert_base64_to_data_uri(self.to_base64())
|
||||
|
||||
# Returns openai.types.chat.ChatCompletionContentPartImageParam, which is a TypedDict
|
||||
# We don't use the explicit type annotation so that we can avoid a dependency on the OpenAI Python SDK in this package.
|
||||
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> Dict[str, Any]:
|
||||
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
# Custom validation
|
||||
def validate(value: Any, validation_info: ValidationInfo) -> Image:
|
||||
if isinstance(value, dict):
|
||||
base_64 = cast(str | None, value.get("data")) # type: ignore
|
||||
if base_64 is None:
|
||||
raise ValueError("Expected 'data' key in the dictionary")
|
||||
return cls.from_base64(base_64)
|
||||
elif isinstance(value, cls):
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")
|
||||
|
||||
# Custom serialization
|
||||
def serialize(value: Image) -> dict[str, Any]:
|
||||
return {"data": value.to_base64()}
|
||||
|
||||
return core_schema.with_info_after_validator_function(
|
||||
validate,
|
||||
core_schema.any_schema(), # Accept any type; adjust if needed
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
|
||||
)
|
||||
|
||||
|
||||
def _convert_base64_to_data_uri(base64_image: str) -> str:
|
||||
def _get_mime_type_from_data_uri(base64_image: str) -> str:
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(base64_image)
|
||||
# Check the first few bytes for known signatures
|
||||
if image_data.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
|
||||
return "image/gif"
|
||||
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return "image/jpeg" # use jpeg for unknown formats, best guess.
|
||||
|
||||
mime_type = _get_mime_type_from_data_uri(base64_image)
|
||||
data_uri = f"data:{mime_type};base64,{base64_image}"
|
||||
return data_uri
|
||||
83
agent_dhal/agentdhal_core/_intervention.py
Normal file
83
agent_dhal/agentdhal_core/_intervention.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._message_context import MessageContext
|
||||
|
||||
__all__ = [
|
||||
"DropMessage",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
]
|
||||
|
||||
|
||||
@final
|
||||
class DropMessage:
|
||||
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class InterventionHandler(Protocol):
|
||||
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`agentdhal_core.base.AgentRuntime`.
|
||||
|
||||
The handler is called when the message is submitted to the runtime.
|
||||
|
||||
Currently the only runtime which supports this is the :class:`agentdhal_core.base.SingleThreadedAgentRuntime`.
|
||||
|
||||
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
|
||||
if isinstance(message, MyMessage):
|
||||
message.content = message.content.upper()
|
||||
return message
|
||||
|
||||
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])
|
||||
|
||||
"""
|
||||
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
"""Called when a message is submitted to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.send_message`."""
|
||||
...
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
"""Called when a message is published to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.publish_message`."""
|
||||
...
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
|
||||
...
|
||||
|
||||
|
||||
class DefaultInterventionHandler(InterventionHandler):
|
||||
"""Simple class that provides a default implementation for all intervention
|
||||
handler methods, that simply returns the message unchanged. Allows for easy
|
||||
subclassing to override only the desired methods."""
|
||||
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
14
agent_dhal/agentdhal_core/_message_context.py
Normal file
14
agent_dhal/agentdhal_core/_message_context.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageContext:
|
||||
sender: AgentId | None
|
||||
topic_id: TopicId | None
|
||||
is_rpc: bool
|
||||
cancellation_token: CancellationToken
|
||||
message_id: str
|
||||
31
agent_dhal/agentdhal_core/_message_handler_context.py
Normal file
31
agent_dhal/agentdhal_core/_message_handler_context.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
|
||||
|
||||
class MessageHandlerContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling."
|
||||
)
|
||||
|
||||
_MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls._MESSAGE_HANDLER_CONTEXT.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e
|
||||
264
agent_dhal/agentdhal_core/_queue.py
Normal file
264
agent_dhal/agentdhal_core/_queue.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# Copy of Asyncio queue: https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py
|
||||
# So that shutdown can be used in <3.13
|
||||
# Modified to work outside of the asyncio package
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import threading
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
class _LoopBoundMixin:
|
||||
_loop = None
|
||||
|
||||
def _get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if self._loop is None:
|
||||
with _global_lock:
|
||||
if self._loop is None:
|
||||
self._loop = loop
|
||||
if loop is not self._loop:
|
||||
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||
return loop
|
||||
|
||||
|
||||
class QueueShutDown(Exception):
|
||||
"""Raised when putting on to or getting from a shut-down Queue."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Queue(_LoopBoundMixin, Generic[T]):
|
||||
def __init__(self, maxsize: int = 0):
|
||||
self._maxsize = maxsize
|
||||
self._getters = collections.deque[asyncio.Future[None]]()
|
||||
self._putters = collections.deque[asyncio.Future[None]]()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = asyncio.Event()
|
||||
self._finished.set()
|
||||
self._queue = collections.deque[T]()
|
||||
self._is_shutdown = False
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
|
||||
def _get(self) -> T:
|
||||
return self._queue.popleft()
|
||||
|
||||
def _put(self, item: T) -> None:
|
||||
self._queue.append(item)
|
||||
|
||||
# End of the overridable methods.
|
||||
|
||||
def _wakeup_next(self, waiters: collections.deque[asyncio.Future[None]]) -> None:
|
||||
# Wake up the next waiter (if any) that isn't cancelled.
|
||||
while waiters:
|
||||
waiter = waiters.popleft()
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
break
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{type(self).__name__} at {id(self):#x} {self._format()}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<{type(self).__name__} {self._format()}>"
|
||||
|
||||
def _format(self) -> str:
|
||||
result = f"maxsize={self._maxsize!r}"
|
||||
if getattr(self, "_queue", None):
|
||||
result += f" _queue={list(self._queue)!r}"
|
||||
if self._getters:
|
||||
result += f" _getters[{len(self._getters)}]"
|
||||
if self._putters:
|
||||
result += f" _putters[{len(self._putters)}]"
|
||||
if self._unfinished_tasks:
|
||||
result += f" tasks={self._unfinished_tasks}"
|
||||
if self._is_shutdown:
|
||||
result += " shutdown"
|
||||
return result
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def maxsize(self) -> int:
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Return True if the queue is empty, False otherwise."""
|
||||
return not self._queue
|
||||
|
||||
def full(self) -> bool:
|
||||
"""Return True if there are maxsize items in the queue.
|
||||
|
||||
Note: if the Queue was initialized with maxsize=0 (the default),
|
||||
then full() is never True.
|
||||
"""
|
||||
if self._maxsize <= 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self._maxsize
|
||||
|
||||
async def put(self, item: T) -> None:
|
||||
"""Put an item into the queue.
|
||||
|
||||
Put an item into the queue. If the queue is full, wait until a free
|
||||
slot is available before adding item.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
while self.full():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
putter = self._get_loop().create_future()
|
||||
self._putters.append(putter)
|
||||
try:
|
||||
await putter
|
||||
except:
|
||||
putter.cancel() # Just in case putter is not done yet.
|
||||
try:
|
||||
# Clean self._putters from canceled putters.
|
||||
self._putters.remove(putter)
|
||||
except ValueError:
|
||||
# The putter could be removed from self._putters by a
|
||||
# previous get_nowait call or a shutdown call.
|
||||
pass
|
||||
if not self.full() and not putter.cancelled():
|
||||
# We were woken up by get_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._putters)
|
||||
raise
|
||||
return self.put_nowait(item)
|
||||
|
||||
def put_nowait(self, item: T) -> None:
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise QueueFull.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
if self.full():
|
||||
raise asyncio.QueueFull
|
||||
self._put(item)
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._wakeup_next(self._getters)
|
||||
|
||||
async def get(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If queue is empty, wait until an item is available.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
while self.empty():
|
||||
if self._is_shutdown and self.empty():
|
||||
raise QueueShutDown
|
||||
getter = self._get_loop().create_future()
|
||||
self._getters.append(getter)
|
||||
try:
|
||||
await getter
|
||||
except:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call, or a shutdown call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._getters)
|
||||
raise
|
||||
return self.get_nowait()
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Return an item if one is immediately available, else raise QueueEmpty.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
if self.empty():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
raise asyncio.QueueEmpty
|
||||
item = self._get()
|
||||
self._wakeup_next(self._putters)
|
||||
return item
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each get() used to fetch a task,
|
||||
a subsequent call to task_done() tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a join() is currently blocking, it will resume when all items have
|
||||
been processed (meaning that a task_done() call was received for every
|
||||
item that had been put() into the queue).
|
||||
|
||||
shutdown(immediate=True) calls task_done() for each remaining item in
|
||||
the queue.
|
||||
|
||||
Raises ValueError if called more times than there were items placed in
|
||||
the queue.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError("task_done() called too many times")
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
async def join(self) -> None:
|
||||
"""Block until all items in the queue have been gotten and processed.
|
||||
|
||||
The count of unfinished tasks goes up whenever an item is added to the
|
||||
queue. The count goes down whenever a consumer calls task_done() to
|
||||
indicate that the item was retrieved and all work on it is complete.
|
||||
When the count of unfinished tasks drops to zero, join() unblocks.
|
||||
"""
|
||||
if self._unfinished_tasks > 0:
|
||||
await self._finished.wait()
|
||||
|
||||
def shutdown(self, immediate: bool = False) -> None:
|
||||
"""Shut-down the queue, making queue gets and puts raise QueueShutDown.
|
||||
|
||||
By default, gets will only raise once the queue is empty. Set
|
||||
'immediate' to True to make gets raise immediately instead.
|
||||
|
||||
All blocked callers of put() and get() will be unblocked. If
|
||||
'immediate', a task is marked as done for each item remaining in
|
||||
the queue, which may unblock callers of join().
|
||||
"""
|
||||
self._is_shutdown = True
|
||||
if immediate:
|
||||
while not self.empty():
|
||||
self._get()
|
||||
if self._unfinished_tasks > 0:
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
# All getters need to re-check queue-empty to raise ShutDown
|
||||
while self._getters:
|
||||
getter = self._getters.popleft()
|
||||
if not getter.done():
|
||||
getter.set_result(None)
|
||||
while self._putters:
|
||||
putter = self._putters.popleft()
|
||||
if not putter.done():
|
||||
putter.set_result(None)
|
||||
518
agent_dhal/agentdhal_core/_routed_agent.py
Normal file
518
agent_dhal/agentdhal_core/_routed_agent.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
DefaultDict,
|
||||
List,
|
||||
Literal,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_type_hints,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._base_agent import BaseAgent
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._type_helpers import AnyType, get_types
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
logger = logging.getLogger("agentdhal_core")
|
||||
|
||||
AgentT = TypeVar("AgentT")
|
||||
ReceivesT = TypeVar("ReceivesT")
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
# TODO: Generic typevar bound binding U to agent type
|
||||
# Can't do because python doesnt support it
|
||||
|
||||
|
||||
# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
|
||||
# Revisit this later to see if we can remove the ignore.
|
||||
@runtime_checkable
|
||||
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore
|
||||
target_types: Sequence[type]
|
||||
produces_types: Sequence[type]
|
||||
is_message_handler: Literal[True]
|
||||
router: Callable[[ReceivesT, MessageContext], bool]
|
||||
|
||||
# agent_instance binds to self in the method
|
||||
@staticmethod
|
||||
async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocol for the outer function to check checked arg names
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def message_handler(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for generic message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle both event and RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
wrapper_handler.router = match or (lambda _message, _ctx: True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
def event(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, None]
|
||||
):
|
||||
"""Decorator for event message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle event messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The event message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must return `None`.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found. Please provide a type hint for the message parameter.")
|
||||
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx) # type: ignore
|
||||
|
||||
if return_value is not None:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} is not None.")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.")
|
||||
|
||||
return None
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
# Wrap the match function with a check on the is_rpc flag.
|
||||
wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def rpc(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for RPC message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
class RoutedAgent(BaseAgent):
|
||||
"""A base class for agents that route messages to handlers based on the type of the message
|
||||
and optional matching functions.
|
||||
|
||||
To create a routed agent, subclass this class and add message handlers as methods decorated with
|
||||
either :func:`event` or :func:`rpc` decorator.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
from agentdhal_core import MessageContext
|
||||
from agentdhal_core import RoutedAgent, event, rpc
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageWithContent:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
pass
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self):
|
||||
super().__init__("MyAgent")
|
||||
|
||||
@event
|
||||
async def handle_event_message(self, message: Message, ctx: MessageContext) -> None:
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(MessageWithContent("event handled"), ctx.topic_id)
|
||||
|
||||
@rpc(match=lambda message, ctx: message.content == "special") # type: ignore
|
||||
async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
|
||||
return Response()
|
||||
"""
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: DefaultDict[
|
||||
Type[Any],
|
||||
List[MessageHandler[RoutedAgent, Any, Any]],
|
||||
] = DefaultDict(list)
|
||||
|
||||
handlers = self._discover_handlers()
|
||||
for message_handler in handlers:
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type].append(message_handler)
|
||||
|
||||
super().__init__(description)
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
"""Handle a message by routing it to the appropriate message handler.
|
||||
Do not override this method in subclasses. Instead, add message handlers as methods decorated with
|
||||
either the :func:`event` or :func:`rpc` decorator."""
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
handlers = self._handlers.get(key_type) # type: ignore
|
||||
if handlers is not None:
|
||||
# Iterate over all handlers for this matching message type.
|
||||
# Call the first handler whose router returns True and then return the result.
|
||||
for h in handlers:
|
||||
if h.router(message, ctx):
|
||||
return await h(self, message, ctx)
|
||||
return await self.on_unhandled_message(message, ctx) # type: ignore
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
"""Called when a message is received that does not have a matching message handler.
|
||||
The default implementation logs an info message."""
|
||||
logger.info(f"Unhandled message: {message}")
|
||||
|
||||
@classmethod
|
||||
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
|
||||
handlers: List[MessageHandler[Any, Any, Any]] = []
|
||||
for attr in dir(cls):
|
||||
if callable(getattr(cls, attr, None)):
|
||||
# Since we are getting it from the class, self is not bound
|
||||
handler = getattr(cls, attr)
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
handlers.append(cast(MessageHandler[Any, Any, Any], handler))
|
||||
return handlers
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
# TODO handle deduplication
|
||||
handlers = cls._discover_handlers()
|
||||
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
|
||||
types.extend(cls.internal_extra_handles_types)
|
||||
for handler in handlers:
|
||||
for t in handler.target_types:
|
||||
# TODO: support different serializers
|
||||
serializers = try_get_known_serializers_for_type(t)
|
||||
if len(serializers) == 0:
|
||||
raise ValueError(f"No serializers found for type {t}.")
|
||||
|
||||
types.append((t, try_get_known_serializers_for_type(t)))
|
||||
return types
|
||||
78
agent_dhal/agentdhal_core/_runtime_impl_helpers.py
Normal file
78
agent_dhal/agentdhal_core/_runtime_impl_helpers.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from collections import defaultdict
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Sequence, Set
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
async def get_impl(
|
||||
*,
|
||||
id_or_type: AgentId | AgentType | str,
|
||||
key: str,
|
||||
lazy: bool,
|
||||
instance_getter: Callable[[AgentId], Awaitable[Agent]],
|
||||
) -> AgentId:
|
||||
if isinstance(id_or_type, AgentId):
|
||||
if not lazy:
|
||||
await instance_getter(id_or_type)
|
||||
|
||||
return id_or_type
|
||||
|
||||
type_str = id_or_type if isinstance(id_or_type, str) else id_or_type.type
|
||||
id = AgentId(type_str, key)
|
||||
if not lazy:
|
||||
await instance_getter(id)
|
||||
|
||||
return id
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
def __init__(self) -> None:
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Subscription]:
|
||||
return self._subscriptions
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub == subscription for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
# Check if the subscription exists
|
||||
if not any(sub.id == id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription does not exist")
|
||||
|
||||
def is_not_sub(x: Subscription) -> bool:
|
||||
return x.id != id
|
||||
|
||||
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
|
||||
|
||||
# Rebuild the subscriptions
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
async def get_subscribed_recipients(self, topic: TopicId) -> List[AgentId]:
|
||||
if topic not in self._seen_topics:
|
||||
self._build_for_new_topic(topic)
|
||||
return self._subscribed_recipients[topic]
|
||||
|
||||
# TODO: optimize this...
|
||||
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
|
||||
self._subscribed_recipients.clear()
|
||||
for topic in topics:
|
||||
self._build_for_new_topic(topic)
|
||||
|
||||
def _build_for_new_topic(self, topic: TopicId) -> None:
|
||||
self._seen_topics.add(topic)
|
||||
for subscription in self._subscriptions:
|
||||
if subscription.is_match(topic):
|
||||
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
|
||||
258
agent_dhal/agentdhal_core/_serialization.py
Normal file
258
agent_dhal/agentdhal_core/_serialization.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
|
||||
from google.protobuf import any_pb2
|
||||
from google.protobuf.message import Message
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._type_helpers import is_union
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MessageSerializer(Protocol[T]):
|
||||
@property
|
||||
def data_content_type(self) -> str: ...
|
||||
|
||||
@property
|
||||
def type_name(self) -> str: ...
|
||||
|
||||
def deserialize(self, payload: bytes) -> T: ...
|
||||
|
||||
def serialize(self, message: T) -> bytes: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsDataclass(Protocol):
|
||||
# as already noted in comments, checking for this attribute is currently
|
||||
# the most reliable way to ascertain that something is a dataclass
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
|
||||
|
||||
def is_dataclass(cls: type[Any]) -> bool:
|
||||
return hasattr(cls, "__dataclass_fields__")
|
||||
|
||||
|
||||
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
|
||||
# iterate fields and check if any of them are dataclasses
|
||||
return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def contains_a_union(cls: type[IsDataclass]) -> bool:
|
||||
return any(is_union(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
|
||||
for f in fields(cls):
|
||||
field_type = f.type
|
||||
# Resolve forward references and other annotations
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
# If the field type is directly a subclass of BaseModel
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return True
|
||||
|
||||
# If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc.
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
# Recursively check the argument types
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
# Handle nested generics like List[List[BaseModel]]
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
# Handle Union types
|
||||
elif args:
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_nested_base_model_in_type(tp: Any) -> bool:
|
||||
"""Helper function to check if a type or its arguments is a BaseModel subclass."""
|
||||
origin = get_origin(tp)
|
||||
args = get_args(tp)
|
||||
|
||||
if isinstance(tp, type) and issubclass(tp, BaseModel):
|
||||
return True
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
|
||||
JSON_DATA_CONTENT_TYPE = "application/json"
|
||||
"""JSON data content type"""
|
||||
|
||||
# TODO: what's the correct content type? There seems to be some disagreement over what it should be
|
||||
PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf"
|
||||
"""Protobuf data content type"""
|
||||
|
||||
|
||||
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
|
||||
def __init__(self, cls: type[DataclassT]) -> None:
|
||||
if contains_a_union(cls):
|
||||
raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")
|
||||
|
||||
if has_nested_dataclass(cls) or has_nested_base_model(cls):
|
||||
raise ValueError(
|
||||
"Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
|
||||
)
|
||||
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> DataclassT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls(**json.loads(message_str))
|
||||
|
||||
def serialize(self, message: DataclassT) -> bytes:
|
||||
return json.dumps(asdict(message)).encode("utf-8")
|
||||
|
||||
|
||||
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
||||
|
||||
|
||||
class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
|
||||
def __init__(self, cls: type[PydanticT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> PydanticT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls.model_validate_json(message_str)
|
||||
|
||||
def serialize(self, message: PydanticT) -> bytes:
|
||||
return message.model_dump_json().encode("utf-8")
|
||||
|
||||
|
||||
ProtobufT = TypeVar("ProtobufT", bound=Message)
|
||||
|
||||
|
||||
# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
|
||||
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
|
||||
def __init__(self, cls: type[ProtobufT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return PROTOBUF_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> ProtobufT:
|
||||
# Parse payload into a proto any
|
||||
any_proto = any_pb2.Any()
|
||||
any_proto.ParseFromString(payload)
|
||||
|
||||
destination_message = self.cls()
|
||||
|
||||
if not any_proto.Unpack(destination_message): # type: ignore
|
||||
raise ValueError(f"Failed to unpack payload into {self.cls}")
|
||||
|
||||
return destination_message
|
||||
|
||||
def serialize(self, message: ProtobufT) -> bytes:
|
||||
any_proto = any_pb2.Any()
|
||||
any_proto.Pack(message) # type: ignore
|
||||
return any_proto.SerializeToString()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnknownPayload:
|
||||
type_name: str
|
||||
data_content_type: str
|
||||
payload: bytes
|
||||
|
||||
|
||||
def _type_name(cls: type[Any] | Any) -> str:
|
||||
# If cls is a protobuf, then we need to determine the descriptor
|
||||
if isinstance(cls, type):
|
||||
if issubclass(cls, Message):
|
||||
return cast(str, cls.DESCRIPTOR.full_name)
|
||||
elif isinstance(cls, Message):
|
||||
return cast(str, cls.DESCRIPTOR.full_name)
|
||||
|
||||
if isinstance(cls, type):
|
||||
return cls.__name__
|
||||
else:
|
||||
return cast(str, cls.__class__.__name__)
|
||||
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer[Any]]:
|
||||
""":meta private:"""
|
||||
|
||||
serializers: List[MessageSerializer[Any]] = []
|
||||
if issubclass(cls, BaseModel):
|
||||
serializers.append(PydanticJsonMessageSerializer(cls))
|
||||
elif is_dataclass(cls):
|
||||
serializers.append(DataclassJsonMessageSerializer(cls))
|
||||
elif issubclass(cls, Message):
|
||||
serializers.append(ProtobufMessageSerializer(cls))
|
||||
|
||||
return serializers
|
||||
|
||||
|
||||
class SerializationRegistry:
|
||||
""":meta private:"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# type_name, data_content_type -> serializer
|
||||
self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {}
|
||||
|
||||
def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
if isinstance(serializer, Sequence):
|
||||
for c in serializer:
|
||||
self.add_serializer(c)
|
||||
return
|
||||
|
||||
self._serializers[(serializer.type_name, serializer.data_content_type)] = serializer
|
||||
|
||||
def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any:
|
||||
serializer = self._serializers.get((type_name, data_content_type))
|
||||
if serializer is None:
|
||||
return UnknownPayload(type_name, data_content_type, payload)
|
||||
|
||||
return serializer.deserialize(payload)
|
||||
|
||||
def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes:
|
||||
serializer = self._serializers.get((type_name, data_content_type))
|
||||
if serializer is None:
|
||||
raise ValueError(f"Unknown type {type_name} with content type {data_content_type}")
|
||||
|
||||
return serializer.serialize(message)
|
||||
|
||||
def is_registered(self, type_name: str, data_content_type: str) -> bool:
|
||||
return (type_name, data_content_type) in self._serializers
|
||||
|
||||
def type_name(self, message: Any) -> str:
|
||||
return _type_name(message)
|
||||
1029
agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py
Normal file
1029
agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py
Normal file
File diff suppressed because it is too large
Load Diff
65
agent_dhal/agentdhal_core/_subscription.py
Normal file
65
agent_dhal/agentdhal_core/_subscription.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, Callable, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Subscription(Protocol):
|
||||
"""Subscriptions define the topics that an agent is interested in."""
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the ID of the subscription.
|
||||
|
||||
Implementations should return a unique ID for the subscription. Usually this is a UUID.
|
||||
|
||||
Returns:
|
||||
str: ID of the subscription.
|
||||
"""
|
||||
...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check if two subscriptions are equal.
|
||||
|
||||
Args:
|
||||
other (object): Other subscription to compare against.
|
||||
|
||||
Returns:
|
||||
bool: True if the subscriptions are equal, False otherwise.
|
||||
"""
|
||||
if not isinstance(other, Subscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
"""Check if a given topic_id matches the subscription.
|
||||
|
||||
Args:
|
||||
topic_id (TopicId): TopicId to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the topic_id matches the subscription, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
"""Map a topic_id to an agent. Should only be called if `is_match` returns True for the given topic_id.
|
||||
|
||||
Args:
|
||||
topic_id (TopicId): TopicId to map.
|
||||
|
||||
Returns:
|
||||
AgentId: ID of the agent that should handle the topic_id.
|
||||
|
||||
Raises:
|
||||
CantHandleException: If the subscription cannot handle the topic_id.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Helper alias to represent the lambdas used to define subscriptions
|
||||
UnboundSubscription = Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
33
agent_dhal/agentdhal_core/_subscription_context.py
Normal file
33
agent_dhal/agentdhal_core/_subscription_context.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_type import AgentType
|
||||
|
||||
|
||||
class SubscriptionInstantiationContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext cannot be instantiated. It is a static class that provides context management for subscription instantiation."
|
||||
)
|
||||
|
||||
_SUBSCRIPTION_CONTEXT_VAR: ClassVar[ContextVar[AgentType]] = ContextVar("_SUBSCRIPTION_CONTEXT_VAR")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentType) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_type(cls) -> AgentType:
|
||||
try:
|
||||
return cls._SUBSCRIPTION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
25
agent_dhal/agentdhal_core/_telemetry/__init__.py
Normal file
25
agent_dhal/agentdhal_core/_telemetry/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from ._genai import (
|
||||
trace_create_agent_span,
|
||||
trace_invoke_agent_span,
|
||||
trace_tool_span,
|
||||
)
|
||||
from ._propagation import (
|
||||
EnvelopeMetadata,
|
||||
TelemetryMetadataContainer,
|
||||
get_telemetry_envelope_metadata,
|
||||
get_telemetry_grpc_metadata,
|
||||
)
|
||||
from ._tracing import TraceHelper
|
||||
from ._tracing_config import MessageRuntimeTracingConfig
|
||||
|
||||
__all__ = [
|
||||
"EnvelopeMetadata",
|
||||
"get_telemetry_envelope_metadata",
|
||||
"get_telemetry_grpc_metadata",
|
||||
"TelemetryMetadataContainer",
|
||||
"TraceHelper",
|
||||
"MessageRuntimeTracingConfig",
|
||||
"trace_create_agent_span",
|
||||
"trace_invoke_agent_span",
|
||||
"trace_tool_span",
|
||||
]
|
||||
1
agent_dhal/agentdhal_core/_telemetry/_constants.py
Normal file
1
agent_dhal/agentdhal_core/_telemetry/_constants.py
Normal file
@@ -0,0 +1 @@
|
||||
NAMESPACE = "agentdhal"
|
||||
214
agent_dhal/agentdhal_core/_telemetry/_genai.py
Normal file
214
agent_dhal/agentdhal_core/_telemetry/_genai.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Span, SpanKind
|
||||
|
||||
from .._agent_instantiation import AgentInstantiationContext
|
||||
|
||||
# OpenTelemetry semantic convention constants for GenAI operations
|
||||
# Copied from opentelemetry-semantic-conventions to avoid dependency
|
||||
|
||||
# GenAI Agent attributes
|
||||
GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description"
|
||||
GEN_AI_AGENT_ID = "gen_ai.agent.id"
|
||||
GEN_AI_AGENT_NAME = "gen_ai.agent.name"
|
||||
|
||||
# GenAI Operation attributes
|
||||
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
# GenAI Tool attributes
|
||||
GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id"
|
||||
GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description"
|
||||
GEN_AI_TOOL_NAME = "gen_ai.tool.name"
|
||||
|
||||
# Error attributes
|
||||
ERROR_TYPE = "error.type"
|
||||
|
||||
|
||||
class GenAiOperationNameValues(Enum):
|
||||
"""Enum for GenAI operation name values."""
|
||||
|
||||
CHAT = "chat"
|
||||
CREATE_AGENT = "create_agent"
|
||||
EMBEDDINGS = "embeddings"
|
||||
EXECUTE_TOOL = "execute_tool"
|
||||
GENERATE_CONTENT = "generate_content"
|
||||
INVOKE_AGENT = "invoke_agent"
|
||||
TEXT_COMPLETION = "text_completion"
|
||||
|
||||
|
||||
# Constant for system name
|
||||
GENAI_SYSTEM_AUTOGEN = "agentdhal"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_tool_span(
|
||||
tool_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
tool_description: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for tool execution following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
|
||||
Args:
|
||||
tool_name (str): The name of the tool being executed.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
tool_description (Optional[str]): A description of the tool.
|
||||
tool_call_id (Optional[str]): A unique identifier for the tool call.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.EXECUTE_TOOL.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_TOOL_NAME: tool_name,
|
||||
}
|
||||
if tool_description is not None:
|
||||
span_attributes[GEN_AI_TOOL_DESCRIPTION] = tool_description
|
||||
if tool_call_id is not None:
|
||||
span_attributes[GEN_AI_TOOL_CALL_ID] = tool_call_id
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.EXECUTE_TOOL.value} {tool_name}",
|
||||
kind=SpanKind.INTERNAL,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_create_agent_span(
|
||||
agent_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_description: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for agent creation following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
Args:
|
||||
agent_name (str): The name of the agent being created.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
agent_id (Optional[str]): The unique identifier for the agent.
|
||||
agent_description (Optional[str]): A description of the agent.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.CREATE_AGENT.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_AGENT_NAME: agent_name,
|
||||
}
|
||||
if agent_id is None:
|
||||
# Try to see if we can get the agent ID from the current context
|
||||
try:
|
||||
agent_id = str(AgentInstantiationContext.current_agent_id())
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
if agent_id is not None:
|
||||
span_attributes[GEN_AI_AGENT_ID] = agent_id
|
||||
if agent_description is not None:
|
||||
span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.CREATE_AGENT.value} {agent_name}",
|
||||
kind=SpanKind.CLIENT,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_invoke_agent_span(
|
||||
agent_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_description: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for invoking an agent following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
Args:
|
||||
agent_name (str): The name of the agent being invoked.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
agent_id (Optional[str]): The unique identifier for the agent.
|
||||
agent_description (Optional[str]): A description of the agent.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.INVOKE_AGENT.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_AGENT_NAME: agent_name,
|
||||
}
|
||||
if agent_id is not None:
|
||||
span_attributes[GEN_AI_AGENT_ID] = agent_id
|
||||
if agent_description is not None:
|
||||
span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.INVOKE_AGENT.value} {agent_name}",
|
||||
kind=SpanKind.CLIENT,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
127
agent_dhal/agentdhal_core/_telemetry/_propagation.py
Normal file
127
agent_dhal/agentdhal_core/_telemetry/_propagation.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Mapping, Optional, Sequence
|
||||
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.propagate import extract
|
||||
from opentelemetry.trace import Link, get_current_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class EnvelopeMetadata:
|
||||
"""Metadata for an envelope."""
|
||||
|
||||
traceparent: Optional[str] = None
|
||||
tracestate: Optional[str] = None
|
||||
links: Optional[Sequence[Link]] = None
|
||||
|
||||
|
||||
def _get_carrier_for_envelope_metadata(envelope_metadata: EnvelopeMetadata) -> Dict[str, str]:
|
||||
carrier: Dict[str, str] = {}
|
||||
if envelope_metadata.traceparent is not None:
|
||||
carrier["traceparent"] = envelope_metadata.traceparent
|
||||
if envelope_metadata.tracestate is not None:
|
||||
carrier["tracestate"] = envelope_metadata.tracestate
|
||||
return carrier
|
||||
|
||||
|
||||
def get_telemetry_envelope_metadata() -> EnvelopeMetadata:
|
||||
"""
|
||||
Retrieves the telemetry envelope metadata.
|
||||
|
||||
Returns:
|
||||
EnvelopeMetadata: The envelope metadata containing the traceparent and tracestate.
|
||||
"""
|
||||
carrier: Dict[str, str] = {}
|
||||
TraceContextTextMapPropagator().inject(carrier)
|
||||
return EnvelopeMetadata(
|
||||
traceparent=carrier.get("traceparent"),
|
||||
tracestate=carrier.get("tracestate"),
|
||||
)
|
||||
|
||||
|
||||
def _get_carrier_for_remote_call_metadata(remote_call_metadata: Mapping[str, str]) -> Dict[str, str]:
|
||||
carrier: Dict[str, str] = {}
|
||||
traceparent = remote_call_metadata.get("traceparent")
|
||||
tracestate = remote_call_metadata.get("tracestate")
|
||||
if traceparent:
|
||||
carrier["traceparent"] = traceparent
|
||||
if tracestate:
|
||||
carrier["tracestate"] = tracestate
|
||||
return carrier
|
||||
|
||||
|
||||
def get_telemetry_grpc_metadata(existingMetadata: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Retrieves the telemetry gRPC metadata.
|
||||
|
||||
Args:
|
||||
existingMetadata (Optional[Mapping[str, str]]): The existing metadata to include in the gRPC metadata.
|
||||
|
||||
Returns:
|
||||
Mapping[str, str]: The gRPC metadata containing the traceparent and tracestate.
|
||||
"""
|
||||
carrier: Dict[str, str] = {}
|
||||
TraceContextTextMapPropagator().inject(carrier)
|
||||
traceparent = carrier.get("traceparent")
|
||||
tracestate = carrier.get("tracestate")
|
||||
metadata: Dict[str, str] = {}
|
||||
if existingMetadata is not None:
|
||||
for key, value in existingMetadata.items():
|
||||
metadata[key] = value
|
||||
if traceparent is not None:
|
||||
metadata["traceparent"] = traceparent
|
||||
if tracestate is not None:
|
||||
metadata["tracestate"] = tracestate
|
||||
return metadata
|
||||
|
||||
|
||||
TelemetryMetadataContainer = Optional[EnvelopeMetadata] | Mapping[str, str]
|
||||
|
||||
|
||||
def get_telemetry_context(metadata: TelemetryMetadataContainer) -> Context:
|
||||
"""
|
||||
Retrieves the telemetry context from the given metadata.
|
||||
|
||||
Args:
|
||||
metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry context.
|
||||
|
||||
Returns:
|
||||
Context: The telemetry context extracted from the metadata, or an empty context if the metadata is None.
|
||||
"""
|
||||
if metadata is None:
|
||||
return Context()
|
||||
elif isinstance(metadata, EnvelopeMetadata):
|
||||
return extract(_get_carrier_for_envelope_metadata(metadata))
|
||||
elif hasattr(metadata, "__getitem__"):
|
||||
return extract(_get_carrier_for_remote_call_metadata(metadata))
|
||||
else:
|
||||
raise ValueError(f"Unknown metadata type: {type(metadata)}")
|
||||
|
||||
|
||||
def get_telemetry_links(
|
||||
metadata: TelemetryMetadataContainer,
|
||||
) -> Optional[Sequence[Link]]:
|
||||
"""
|
||||
Retrieves the telemetry links from the given metadata.
|
||||
|
||||
Args:
|
||||
metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry links.
|
||||
|
||||
Returns:
|
||||
Optional[Sequence[Link]]: The telemetry links extracted from the metadata, or None if there are no links.
|
||||
"""
|
||||
if metadata is None:
|
||||
return None
|
||||
elif isinstance(metadata, EnvelopeMetadata):
|
||||
context = extract(_get_carrier_for_envelope_metadata(metadata))
|
||||
elif hasattr(metadata, "__getitem__"):
|
||||
context = extract(_get_carrier_for_remote_call_metadata(metadata))
|
||||
else:
|
||||
return None
|
||||
# Retrieve the extracted SpanContext from the context.
|
||||
linked_span = get_current_span(context)
|
||||
# Use the linked span to get the SpanContext.
|
||||
span_context = linked_span.get_span_context()
|
||||
# Create a Link object using the SpanContext.
|
||||
return [Link(span_context)]
|
||||
99
agent_dhal/agentdhal_core/_telemetry/_tracing.py
Normal file
99
agent_dhal/agentdhal_core/_telemetry/_tracing.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Dict, Generic, Iterator, Optional
|
||||
|
||||
from opentelemetry.trace import NoOpTracerProvider, Span, SpanKind, TracerProvider, get_tracer_provider
|
||||
from opentelemetry.util import types
|
||||
|
||||
from ._propagation import TelemetryMetadataContainer, get_telemetry_links
|
||||
from ._tracing_config import Destination, ExtraAttributes, Operation, TracingConfig
|
||||
|
||||
|
||||
class TraceHelper(Generic[Operation, Destination, ExtraAttributes]):
|
||||
"""
|
||||
TraceHelper is a utility class to assist with tracing operations using OpenTelemetry.
|
||||
|
||||
This class provides a context manager `trace_block` to create and manage spans for tracing operations,
|
||||
following semantic conventions and supporting nested spans through metadata contexts.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tracer_provider: TracerProvider | None,
|
||||
instrumentation_builder_config: TracingConfig[Operation, Destination, ExtraAttributes],
|
||||
) -> None:
|
||||
self.instrumentation_builder_config = instrumentation_builder_config
|
||||
|
||||
disable_runtime_tracing = os.environ.get("AUTOGEN_DISABLE_RUNTIME_TRACING") == "true"
|
||||
if disable_runtime_tracing:
|
||||
self.tracer_provider: TracerProvider = NoOpTracerProvider()
|
||||
self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}")
|
||||
return
|
||||
|
||||
# Evaluate in order: first try tracer_provider param, then get_tracer_provider(), finally fallback to NoOp
|
||||
# This allows for nested tracing with a default tracer provided by the user
|
||||
self.tracer_provider = tracer_provider or get_tracer_provider() or NoOpTracerProvider()
|
||||
self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def trace_block(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
parent: Optional[TelemetryMetadataContainer],
|
||||
*,
|
||||
extraAttributes: ExtraAttributes | None = None,
|
||||
kind: Optional[SpanKind] = None,
|
||||
attributes: Optional[types.Attributes] = None,
|
||||
start_time: Optional[int] = None,
|
||||
record_exception: bool = True,
|
||||
set_status_on_exception: bool = True,
|
||||
end_on_exit: bool = True,
|
||||
) -> Iterator[Span]:
|
||||
"""
|
||||
Thin wrapper on top of start_as_current_span.
|
||||
1. It helps us follow semantic conventions
|
||||
2. It helps us get contexts from metadata so we can get nested spans
|
||||
|
||||
Args:
|
||||
operation (MessagingOperation): The messaging operation being performed.
|
||||
destination (MessagingDestination): The messaging destination being used.
|
||||
parent Optional[TelemetryMetadataContainer]: The parent telemetry metadta context
|
||||
kind (SpanKind, optional): The kind of span. If not provided, it maps to PRODUCER or CONSUMER depending on the operation.
|
||||
extraAttributes (ExtraAttributes, optional): Additional defined attributes for the span. Defaults to None.
|
||||
attributes (Optional[types.Attributes], optional): Additional non-defined attributes for the span. Defaults to None.
|
||||
start_time (Optional[int], optional): The start time of the span. Defaults to None.
|
||||
record_exception (bool, optional): Whether to record exceptions. Defaults to True.
|
||||
set_status_on_exception (bool, optional): Whether to set the status on exception. Defaults to True.
|
||||
end_on_exit (bool, optional): Whether to end the span on exit. Defaults to True.
|
||||
|
||||
Yields:
|
||||
Iterator[Span]: The span object.
|
||||
|
||||
"""
|
||||
span_name = self.instrumentation_builder_config.get_span_name(operation, destination)
|
||||
span_kind = kind or self.instrumentation_builder_config.get_span_kind(operation)
|
||||
# context = get_telemetry_context(parent) if parent else None
|
||||
context = None # TODO: we may need to remove other code for using custom context.
|
||||
links = get_telemetry_links(parent) if parent else None
|
||||
attributes_with_defaults: Dict[str, types.AttributeValue] = {}
|
||||
for key, value in (attributes or {}).items():
|
||||
attributes_with_defaults[key] = value
|
||||
instrumentation_attributes = self.instrumentation_builder_config.build_attributes(
|
||||
operation, destination, extraAttributes
|
||||
)
|
||||
for key, value in instrumentation_attributes.items():
|
||||
attributes_with_defaults[key] = value
|
||||
with self.tracer.start_as_current_span(
|
||||
span_name,
|
||||
context,
|
||||
span_kind,
|
||||
attributes_with_defaults,
|
||||
links,
|
||||
start_time,
|
||||
record_exception,
|
||||
set_status_on_exception,
|
||||
end_on_exit,
|
||||
) as span:
|
||||
yield span
|
||||
201
agent_dhal/agentdhal_core/_telemetry/_tracing_config.py
Normal file
201
agent_dhal/agentdhal_core/_telemetry/_tracing_config.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generic, List, Literal, TypedDict, TypeVar, Union
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util import types
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from .._agent_id import AgentId
|
||||
from .._topic import TopicId
|
||||
from ._constants import NAMESPACE
|
||||
|
||||
logger = logging.getLogger("agentdhal_core")
|
||||
event_logger = logging.getLogger("agentdhal_core.events")
|
||||
|
||||
Operation = TypeVar("Operation", bound=str)
|
||||
Destination = TypeVar("Destination")
|
||||
ExtraAttributes = TypeVar("ExtraAttributes")
|
||||
|
||||
|
||||
class TracingConfig(ABC, Generic[Operation, Destination, ExtraAttributes]):
|
||||
"""
|
||||
A protocol that defines the configuration for instrumentation.
|
||||
|
||||
This protocol specifies the required properties and methods that any
|
||||
instrumentation configuration class must implement. It includes a
|
||||
property to get the name of the module being instrumented and a method
|
||||
to build attributes for the instrumentation configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Returns:
|
||||
The name of the module that is being instrumented.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def build_attributes(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
extraAttributes: ExtraAttributes | None,
|
||||
) -> Dict[str, types.AttributeValue]:
|
||||
"""
|
||||
Builds the attributes for the instrumentation configuration.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: The attributes for the instrumentation configuration.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_span_name(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the span name based on the given operation and destination.
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
destination (Optional[MessagingDestination]): The messaging destination.
|
||||
|
||||
Returns:
|
||||
str: The span name.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_span_kind(
|
||||
self,
|
||||
operation: Operation,
|
||||
) -> SpanKind:
|
||||
"""
|
||||
Determines the span kind based on the given messaging operation.
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
|
||||
Returns:
|
||||
SpanKind: The span kind based on the messaging operation.
|
||||
"""
|
||||
|
||||
|
||||
class ExtraMessageRuntimeAttributes(TypedDict):
|
||||
message_size: NotRequired[int]
|
||||
message_type: NotRequired[str]
|
||||
|
||||
|
||||
MessagingDestination = Union[AgentId, TopicId, str, None]
|
||||
MessagingOperation = Literal["create", "send", "publish", "receive", "intercept", "process", "ack"]
|
||||
|
||||
|
||||
class MessageRuntimeTracingConfig(
|
||||
TracingConfig[MessagingOperation, MessagingDestination, ExtraMessageRuntimeAttributes]
|
||||
):
|
||||
"""
|
||||
A class that defines the configuration for message runtime instrumentation.
|
||||
|
||||
This class implements the TracingConfig protocol and provides
|
||||
the name of the module being instrumented and the attributes for the
|
||||
instrumentation configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, runtime_name: str) -> None:
|
||||
self._runtime_name = runtime_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._runtime_name
|
||||
|
||||
def build_attributes(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
destination: MessagingDestination,
|
||||
extraAttributes: ExtraMessageRuntimeAttributes | None,
|
||||
) -> Dict[str, types.AttributeValue]:
|
||||
attrs: Dict[str, types.AttributeValue] = {
|
||||
"messaging.operation": self._get_operation_type(operation),
|
||||
"messaging.destination": self._get_destination_str(destination),
|
||||
}
|
||||
if extraAttributes:
|
||||
# TODO: Make this more pythonic?
|
||||
if "message_size" in extraAttributes:
|
||||
attrs["messaging.message.envelope.size"] = extraAttributes["message_size"]
|
||||
if "message_type" in extraAttributes:
|
||||
attrs["messaging.message.type"] = extraAttributes["message_type"]
|
||||
return attrs
|
||||
|
||||
def get_span_name(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
destination: MessagingDestination,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the span name based on the given operation and destination.
|
||||
Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-name
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
destination (Optional[MessagingDestination]): The messaging destination.
|
||||
|
||||
Returns:
|
||||
str: The span name.
|
||||
"""
|
||||
span_parts: List[str] = [operation]
|
||||
destination_str = self._get_destination_str(destination)
|
||||
if destination_str:
|
||||
span_parts.append(destination_str)
|
||||
span_name = " ".join(span_parts)
|
||||
return f"{NAMESPACE} {span_name}"
|
||||
|
||||
def get_span_kind(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
) -> SpanKind:
|
||||
"""
|
||||
Determines the span kind based on the given messaging operation.
|
||||
Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-kind
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
|
||||
Returns:
|
||||
SpanKind: The span kind based on the messaging operation.
|
||||
"""
|
||||
if operation in ["create", "send", "publish"]:
|
||||
return SpanKind.PRODUCER
|
||||
elif operation in ["receive", "intercept", "process", "ack"]:
|
||||
return SpanKind.CONSUMER
|
||||
else:
|
||||
return SpanKind.CLIENT
|
||||
|
||||
# TODO: Use stringified convention
|
||||
def _get_destination_str(self, destination: MessagingDestination) -> str:
|
||||
if isinstance(destination, AgentId):
|
||||
return f"{destination.type}.({destination.key})-A"
|
||||
elif isinstance(destination, TopicId):
|
||||
return f"{destination.type}.({destination.source})-T"
|
||||
elif isinstance(destination, str):
|
||||
return destination
|
||||
elif destination is None:
|
||||
return ""
|
||||
else:
|
||||
raise ValueError(f"Unknown destination type: {type(destination)}")
|
||||
|
||||
def _get_operation_type(self, operation: MessagingOperation) -> str:
|
||||
if operation in ["send", "publish"]:
|
||||
return "publish"
|
||||
if operation in ["create"]:
|
||||
return "create"
|
||||
elif operation in ["receive", "intercept", "ack"]:
|
||||
return "receive"
|
||||
elif operation in ["process"]:
|
||||
return "process"
|
||||
else:
|
||||
return "Unknown"
|
||||
47
agent_dhal/agentdhal_core/_topic.py
Normal file
47
agent_dhal/agentdhal_core/_topic.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
def is_valid_topic_type(value: str) -> bool:
|
||||
return bool(re.match(r"^[\w\-\.\:\=]+\Z", value))
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class TopicId:
|
||||
"""
|
||||
TopicId defines the scope of a broadcast message. In essence, agent runtime implements a publish-subscribe model through its broadcast API: when publishing a message, the topic must be specified.
|
||||
|
||||
See here for more information: :ref:`topic_and_subscription_topic`
|
||||
"""
|
||||
|
||||
type: str
|
||||
"""Type of the event that this topic_id contains. Adhere's to the cloud event spec.
|
||||
|
||||
Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z
|
||||
|
||||
Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#type
|
||||
"""
|
||||
|
||||
source: str
|
||||
"""Identifies the context in which an event happened. Adhere's to the cloud event spec.
|
||||
|
||||
Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#source-1
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if is_valid_topic_type(self.type) is False:
|
||||
raise ValueError(f"Invalid topic type: {self.type}. Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}/{self.source}"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, topic_id: str) -> Self:
|
||||
"""Convert a string of the format ``type/source`` into a TopicId"""
|
||||
items = topic_id.split("/", maxsplit=1)
|
||||
if len(items) != 2:
|
||||
raise ValueError(f"Invalid topic id: {topic_id}")
|
||||
type, source = items[0], items[1]
|
||||
return cls(type, source)
|
||||
33
agent_dhal/agentdhal_core/_type_helpers.py
Normal file
33
agent_dhal/agentdhal_core/_type_helpers.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections.abc import Sequence
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Optional, Type, Union, get_args, get_origin
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Union or origin is UnionType
|
||||
|
||||
|
||||
def is_optional(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Optional
|
||||
|
||||
|
||||
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
|
||||
class AnyType:
|
||||
pass
|
||||
|
||||
|
||||
def get_types(t: object) -> Sequence[Type[Any]] | None:
|
||||
if is_union(t):
|
||||
return get_args(t)
|
||||
elif is_optional(t):
|
||||
return tuple(list(get_args(t)) + [NoneType])
|
||||
elif t is Any:
|
||||
return (AnyType,)
|
||||
elif isinstance(t, type):
|
||||
return (t,)
|
||||
elif isinstance(t, NoneType):
|
||||
return (NoneType,)
|
||||
else:
|
||||
return None
|
||||
69
agent_dhal/agentdhal_core/_type_prefix_subscription.py
Normal file
69
agent_dhal/agentdhal_core/_type_prefix_subscription.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import uuid
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class TypePrefixSubscription(Subscription):
|
||||
"""This subscription matches on topics based on a prefix of the type and maps to agents using the source of the topic as the agent key.
|
||||
|
||||
This subscription causes each source to have its own agent instance.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import TypePrefixSubscription
|
||||
|
||||
subscription = TypePrefixSubscription(topic_type_prefix="t1", agent_type="a1")
|
||||
|
||||
In this case:
|
||||
|
||||
- A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1`
|
||||
- A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
- A topic_id with type `t1SUFFIX` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
|
||||
Args:
|
||||
topic_type_prefix (str): Topic type prefix to match against
|
||||
agent_type (str): Agent type to handle this subscription
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type_prefix: str, agent_type: str | AgentType, id: str | None = None):
|
||||
self._topic_type_prefix = topic_type_prefix
|
||||
if isinstance(agent_type, AgentType):
|
||||
self._agent_type = agent_type.type
|
||||
else:
|
||||
self._agent_type = agent_type
|
||||
self._id = id or str(uuid.uuid4())
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def topic_type_prefix(self) -> str:
|
||||
return self._topic_type_prefix
|
||||
|
||||
@property
|
||||
def agent_type(self) -> str:
|
||||
return self._agent_type
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
return topic_id.type.startswith(self._topic_type_prefix)
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
if not self.is_match(topic_id):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TypePrefixSubscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id or (
|
||||
self.agent_type == other.agent_type and self.topic_type_prefix == other.topic_type_prefix
|
||||
)
|
||||
66
agent_dhal/agentdhal_core/_type_subscription.py
Normal file
66
agent_dhal/agentdhal_core/_type_subscription.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import uuid
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class TypeSubscription(Subscription):
|
||||
"""This subscription matches on topics based on the type and maps to agents using the source of the topic as the agent key.
|
||||
|
||||
This subscription causes each source to have its own agent instance.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import TypeSubscription
|
||||
|
||||
subscription = TypeSubscription(topic_type="t1", agent_type="a1")
|
||||
|
||||
In this case:
|
||||
|
||||
- A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1`
|
||||
- A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
|
||||
Args:
|
||||
topic_type (str): Topic type to match against
|
||||
agent_type (str): Agent type to handle this subscription
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type: str, agent_type: str | AgentType, id: str | None = None):
|
||||
self._topic_type = topic_type
|
||||
if isinstance(agent_type, AgentType):
|
||||
self._agent_type = agent_type.type
|
||||
else:
|
||||
self._agent_type = agent_type
|
||||
self._id = id or str(uuid.uuid4())
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def topic_type(self) -> str:
|
||||
return self._topic_type
|
||||
|
||||
@property
|
||||
def agent_type(self) -> str:
|
||||
return self._agent_type
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
return topic_id.type == self._topic_type
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
if not self.is_match(topic_id):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TypeSubscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type)
|
||||
12
agent_dhal/agentdhal_core/_types.py
Normal file
12
agent_dhal/agentdhal_core/_types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCall:
|
||||
id: str
|
||||
# JSON args
|
||||
arguments: str
|
||||
# Function to call
|
||||
name: str
|
||||
21
agent_dhal/agentdhal_core/code_executor/__init__.py
Normal file
21
agent_dhal/agentdhal_core/code_executor/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from ._base import CodeBlock, CodeExecutor, CodeResult
|
||||
from ._func_with_reqs import (
|
||||
Alias,
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
Import,
|
||||
ImportFromModule,
|
||||
with_requirements,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CodeBlock",
|
||||
"CodeExecutor",
|
||||
"CodeResult",
|
||||
"Alias",
|
||||
"ImportFromModule",
|
||||
"Import",
|
||||
"FunctionWithRequirements",
|
||||
"FunctionWithRequirementsStr",
|
||||
"with_requirements",
|
||||
]
|
||||
102
agent_dhal/agentdhal_core/code_executor/_base.py
Normal file
102
agent_dhal/agentdhal_core/code_executor/_base.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from types import TracebackType
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeBlock:
|
||||
"""A code block extracted fromm an agent message."""
|
||||
|
||||
code: str
|
||||
language: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeResult:
|
||||
"""Result of a code execution."""
|
||||
|
||||
exit_code: int
|
||||
output: str
|
||||
|
||||
|
||||
class CodeExecutor(ABC, ComponentBase[BaseModel]):
|
||||
"""Executes code blocks and returns the result.
|
||||
|
||||
This is an abstract base class for code executors. It defines the interface
|
||||
for executing code blocks and returning the result. A concrete implementation
|
||||
of this class should be provided to execute code blocks in a specific
|
||||
environment. For example, :class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` executes
|
||||
code blocks in a command line environment in a Docker container.
|
||||
|
||||
It is recommended for subclass to be used as a context manager to ensure
|
||||
that resources are cleaned up properly. To do this, implement the
|
||||
:meth:`~agentdhal_core.code_executor.CodeExecutor.start` and
|
||||
:meth:`~agentdhal_core.code_executor.CodeExecutor.stop` methods
|
||||
that will be called when entering and exiting the context manager.
|
||||
|
||||
"""
|
||||
|
||||
component_type = "code_executor"
|
||||
|
||||
@abstractmethod
|
||||
async def execute_code_blocks(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CodeResult:
|
||||
"""Execute code blocks and return the result.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
Args:
|
||||
code_blocks (List[CodeBlock]): The code blocks to execute.
|
||||
|
||||
Returns:
|
||||
CodeResult: The result of the code execution.
|
||||
|
||||
Raises:
|
||||
ValueError: Errors in user inputs
|
||||
asyncio.TimeoutError: Code execution timeouts
|
||||
asyncio.CancelledError: CancellationToken evoked during execution
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Start the code executor."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the code executor and release any resources."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def restart(self) -> None:
|
||||
"""Restart the code executor.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
This method is called when the agent is reset.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
|
||||
) -> Optional[bool]:
|
||||
await self.stop()
|
||||
return None
|
||||
277
agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py
Normal file
277
agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from importlib.abc import SourceLoader
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
from textwrap import dedent, indent
|
||||
from typing import Any, Callable, Generic, List, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return func.func
|
||||
|
||||
if isinstance(func, FunctionWithRequirements):
|
||||
code = inspect.getsource(func.func)
|
||||
else:
|
||||
code = inspect.getsource(func)
|
||||
# Strip the decorator
|
||||
if code.startswith("@"):
|
||||
code = code[code.index("\n") + 1 :]
|
||||
return code
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Alias:
|
||||
name: str
|
||||
alias: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImportFromModule:
|
||||
module: str
|
||||
imports: Tuple[Union[str, Alias], ...]
|
||||
|
||||
# backward compatibility
|
||||
def __init__(
|
||||
self,
|
||||
module: str,
|
||||
imports: Union[Tuple[Union[str, Alias], ...], List[Union[str, Alias]]],
|
||||
):
|
||||
object.__setattr__(self, "module", module)
|
||||
if isinstance(imports, list):
|
||||
object.__setattr__(self, "imports", tuple(imports))
|
||||
else:
|
||||
object.__setattr__(self, "imports", imports)
|
||||
|
||||
|
||||
Import = Union[str, ImportFromModule, Alias]
|
||||
|
||||
|
||||
def _import_to_str(im: Import) -> str:
|
||||
if isinstance(im, str):
|
||||
return f"import {im}"
|
||||
elif isinstance(im, Alias):
|
||||
return f"import {im.name} as {im.alias}"
|
||||
else:
|
||||
|
||||
def to_str(i: Union[str, Alias]) -> str:
|
||||
if isinstance(i, str):
|
||||
return i
|
||||
else:
|
||||
return f"{i.name} as {i.alias}"
|
||||
|
||||
imports = ", ".join(map(to_str, im.imports))
|
||||
return f"from {im.module} import {imports}"
|
||||
|
||||
|
||||
class _StringLoader(SourceLoader):
|
||||
def __init__(self, data: str):
|
||||
self.data = data
|
||||
|
||||
def get_source(self, fullname: str) -> str:
|
||||
return self.data
|
||||
|
||||
def get_data(self, path: str) -> bytes:
|
||||
return self.data.encode("utf-8")
|
||||
|
||||
def get_filename(self, fullname: str) -> str:
|
||||
return "<not a real path>/" + fullname + ".py"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirementsStr:
|
||||
func: str
|
||||
compiled_func: Callable[..., Any]
|
||||
_func_name: str
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []):
|
||||
self.func = func
|
||||
self.python_packages = python_packages
|
||||
self.global_imports = global_imports
|
||||
|
||||
module_name = "func_module"
|
||||
loader = _StringLoader(func)
|
||||
spec = spec_from_loader(module_name, loader)
|
||||
if spec is None:
|
||||
raise ValueError("Could not create spec")
|
||||
module = module_from_spec(spec)
|
||||
if spec.loader is None:
|
||||
raise ValueError("Could not create loader")
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not compile function: {e}") from e
|
||||
|
||||
functions = inspect.getmembers(module, inspect.isfunction)
|
||||
if len(functions) != 1:
|
||||
raise ValueError("The string must contain exactly one function")
|
||||
|
||||
self._func_name, self.compiled_func = functions[0]
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("String based function with requirement objects are not directly callable")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirements(Generic[T, P]):
|
||||
func: Callable[P, T]
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_callable(
|
||||
cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirements[T, P]:
|
||||
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
|
||||
|
||||
@staticmethod
|
||||
def from_str(
|
||||
func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirementsStr:
|
||||
return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
|
||||
|
||||
# Type this based on F
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_requirements(
|
||||
python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
|
||||
"""
|
||||
Decorate a function with package and import requirements for code execution environments.
|
||||
|
||||
This decorator makes a function available for reference in dynamically executed code blocks
|
||||
by wrapping it in a `FunctionWithRequirements` object that tracks its dependencies. When the
|
||||
decorated function is passed to a code executor, it can be imported by name in the executed
|
||||
code, with all dependencies automatically handled.
|
||||
|
||||
Args:
|
||||
python_packages (Sequence[str], optional): Python packages required by the function.
|
||||
Can include version specifications (e.g., ["pandas>=1.0.0"]). Defaults to [].
|
||||
global_imports (Sequence[Import], optional): Import statements required by the function.
|
||||
Can be strings ("numpy"), ImportFromModule objects, or Alias objects. Defaults to [].
|
||||
|
||||
Returns:
|
||||
Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: A decorator that wraps
|
||||
the target function, preserving its functionality while registering its dependencies.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import tempfile
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import with_requirements, CodeBlock
|
||||
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
|
||||
import pandas
|
||||
|
||||
@with_requirements(python_packages=["pandas"], global_imports=["pandas"])
|
||||
def load_data() -> pandas.DataFrame:
|
||||
\"\"\"Load some sample data.
|
||||
|
||||
Returns:
|
||||
pandas.DataFrame: A DataFrame with sample data
|
||||
\"\"\"
|
||||
data = {
|
||||
"name": ["John", "Anna", "Peter", "Linda"],
|
||||
"location": ["New York", "Paris", "Berlin", "London"],
|
||||
"age": [24, 13, 53, 33],
|
||||
}
|
||||
return pandas.DataFrame(data)
|
||||
|
||||
async def run_example():
|
||||
# The decorated function can be used in executed code
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data])
|
||||
code = f\"\"\"from {executor.functions_module} import load_data
|
||||
|
||||
# Use the imported function
|
||||
data = load_data()
|
||||
print(data['name'][0])\"\"\"
|
||||
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[CodeBlock(language="python", code=code)],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
print(result.output) # Output: John
|
||||
|
||||
# Run the async example
|
||||
asyncio.run(run_example())
|
||||
"""
|
||||
|
||||
def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
|
||||
func_with_reqs = FunctionWithRequirements(
|
||||
python_packages=python_packages, global_imports=global_imports, func=func
|
||||
)
|
||||
|
||||
functools.update_wrapper(func_with_reqs, func)
|
||||
return func_with_reqs
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_python_functions_file(
|
||||
funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
|
||||
) -> str:
|
||||
""":meta private:"""
|
||||
# First collect all global imports
|
||||
global_imports: Set[Import] = set()
|
||||
for func in funcs:
|
||||
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
|
||||
global_imports.update(func.global_imports)
|
||||
|
||||
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
|
||||
|
||||
for func in funcs:
|
||||
content += _to_code(func) + "\n\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
|
||||
"""Generate a stub for a function as a string
|
||||
|
||||
Args:
|
||||
func (Callable[..., Any]): The function to generate a stub for
|
||||
|
||||
Returns:
|
||||
str: The stub for the function
|
||||
"""
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return to_stub(func.compiled_func)
|
||||
|
||||
content = f"def {func.__name__}{inspect.signature(func)}:\n"
|
||||
docstring = func.__doc__
|
||||
|
||||
if docstring:
|
||||
docstring = dedent(docstring)
|
||||
docstring = '"""' + docstring + '"""'
|
||||
docstring = indent(docstring, " ")
|
||||
content += docstring + "\n"
|
||||
|
||||
content += " ..."
|
||||
return content
|
||||
|
||||
|
||||
def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
|
||||
return _to_code(func)
|
||||
|
||||
|
||||
def import_to_str(im: Import) -> str:
|
||||
return _import_to_str(im)
|
||||
17
agent_dhal/agentdhal_core/exceptions.py
Normal file
17
agent_dhal/agentdhal_core/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
__all__ = ["CantHandleException", "UndeliverableException", "MessageDroppedException", "NotAccessibleError"]
|
||||
|
||||
|
||||
class CantHandleException(Exception):
|
||||
"""Raised when a handler can't handle the exception."""
|
||||
|
||||
|
||||
class UndeliverableException(Exception):
|
||||
"""Raised when a message can't be delivered."""
|
||||
|
||||
|
||||
class MessageDroppedException(Exception):
|
||||
"""Raised when a message is dropped."""
|
||||
|
||||
|
||||
class NotAccessibleError(Exception):
|
||||
"""Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally."""
|
||||
294
agent_dhal/agentdhal_core/logging.py
Normal file
294
agent_dhal/agentdhal_core/logging.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class LLMCallEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, Any]],
|
||||
response: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""To be used by model clients to log the call to the LLM.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable.
|
||||
response (Dict[str, Any]): The response of the call. Must be json serializable.
|
||||
prompt_tokens (int): Number of tokens used in the prompt.
|
||||
completion_tokens (int): Number of tokens used in the completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMCallEvent
|
||||
|
||||
response = {"content": "Hello, world!"}
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20, response=response, messages=messages))
|
||||
|
||||
"""
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMCall"
|
||||
self.kwargs["messages"] = messages
|
||||
self.kwargs["response"] = response
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class LLMStreamStartEvent:
|
||||
"""To be used by model clients to log the start of a stream.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMStreamStartEvent
|
||||
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMStreamStartEvent(messages=messages))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMStreamStart"
|
||||
self.kwargs["messages"] = messages
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class LLMStreamEndEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
response: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""To be used by model clients to log the end of a stream.
|
||||
|
||||
Args:
|
||||
response (Dict[str, Any]): The response of the call. Must be json serializable.
|
||||
prompt_tokens (int): Number of tokens used in the prompt.
|
||||
completion_tokens (int): Number of tokens used in the completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMStreamEndEvent
|
||||
|
||||
response = {"content": "Hello, world!"}
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMStreamEndEvent(prompt_tokens=10, completion_tokens=20, response=response))
|
||||
|
||||
"""
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMStreamEnd"
|
||||
self.kwargs["response"] = response
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class ToolCallEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
result: str,
|
||||
) -> None:
|
||||
"""Used by subclasses of :class:`~agentdhal_core.tools.BaseTool` to log executions of tools.
|
||||
|
||||
Args:
|
||||
tool_name (str): The name of the tool.
|
||||
arguments (Dict[str, Any]): The arguments of the tool. Must be json serializable.
|
||||
result (str): The result of the tool. Must be a string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import ToolCallEvent
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(ToolCallEvent(tool_name="Tool1", call_id="123", arguments={"arg1": "value1"}))
|
||||
|
||||
"""
|
||||
self.kwargs: Dict[str, Any] = {}
|
||||
self.kwargs["type"] = "ToolCall"
|
||||
self.kwargs["tool_name"] = tool_name
|
||||
self.kwargs["arguments"] = arguments
|
||||
self.kwargs["result"] = result
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageKind(Enum):
|
||||
DIRECT = 1
|
||||
PUBLISH = 2
|
||||
RESPOND = 3
|
||||
|
||||
|
||||
class DeliveryStage(Enum):
|
||||
SEND = 1
|
||||
DELIVER = 2
|
||||
|
||||
|
||||
class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
delivery_stage: DeliveryStage,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["delivery_stage"] = str(delivery_stage)
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageDroppedEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["type"] = "MessageDropped"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageHandlerExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
handling_agent: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["handling_agent"] = str(handling_agent)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "MessageHandlerException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class AgentConstructionExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent_id: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["agent_id"] = str(agent_id)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "AgentConstructionException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
11
agent_dhal/agentdhal_core/memory/__init__.py
Normal file
11
agent_dhal/agentdhal_core/memory/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ._base_memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from ._list_memory import ListMemory
|
||||
|
||||
__all__ = [
|
||||
"Memory",
|
||||
"MemoryContent",
|
||||
"MemoryQueryResult",
|
||||
"UpdateContextResult",
|
||||
"MemoryMimeType",
|
||||
"ListMemory",
|
||||
]
|
||||
132
agent_dhal/agentdhal_core/memory/_base_memory.py
Normal file
132
agent_dhal/agentdhal_core/memory/_base_memory.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._image import Image
|
||||
from ..model_context import ChatCompletionContext
|
||||
|
||||
|
||||
class MemoryMimeType(Enum):
|
||||
"""Supported MIME types for memory content."""
|
||||
|
||||
TEXT = "text/plain"
|
||||
JSON = "application/json"
|
||||
MARKDOWN = "text/markdown"
|
||||
IMAGE = "image/*"
|
||||
BINARY = "application/octet-stream"
|
||||
|
||||
|
||||
ContentType = Union[str, bytes, Dict[str, Any], Image]
|
||||
|
||||
|
||||
class MemoryContent(BaseModel):
|
||||
"""A memory content item."""
|
||||
|
||||
content: ContentType
|
||||
"""The content of the memory item. It can be a string, bytes, dict, or :class:`~agentdhal_core.Image`."""
|
||||
|
||||
mime_type: MemoryMimeType | str
|
||||
"""The MIME type of the memory content."""
|
||||
|
||||
metadata: Dict[str, Any] | None = None
|
||||
"""Metadata associated with the memory item."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("mime_type")
|
||||
def serialize_mime_type(self, mime_type: MemoryMimeType | str) -> str:
|
||||
"""Serialize the MIME type to a string."""
|
||||
if isinstance(mime_type, MemoryMimeType):
|
||||
return mime_type.value
|
||||
return mime_type
|
||||
|
||||
|
||||
class MemoryQueryResult(BaseModel):
|
||||
"""Result of a memory :meth:`~agentdhal_core.memory.Memory.query` operation."""
|
||||
|
||||
results: List[MemoryContent]
|
||||
|
||||
|
||||
class UpdateContextResult(BaseModel):
|
||||
"""Result of a memory :meth:`~agentdhal_core.memory.Memory.update_context` operation."""
|
||||
|
||||
memories: MemoryQueryResult
|
||||
|
||||
|
||||
class Memory(ABC, ComponentBase[BaseModel]):
|
||||
"""Protocol defining the interface for memory implementations.
|
||||
|
||||
A memory is the storage for data that can be used to enrich or modify the model context.
|
||||
|
||||
A memory implementation can use any storage mechanism, such as a list, a database, or a file system.
|
||||
It can also use any retrieval mechanism, such as vector search or text search.
|
||||
It is up to the implementation to decide how to store and retrieve data.
|
||||
|
||||
It is also a memory implementation's responsibility to update the model context
|
||||
with relevant memory content based on the current model context and querying the memory store.
|
||||
|
||||
See :class:`~agentdhal_core.memory.ListMemory` for an example implementation.
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
|
||||
@abstractmethod
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
) -> UpdateContextResult:
|
||||
"""
|
||||
Update the provided model context using relevant memory content.
|
||||
|
||||
Args:
|
||||
model_context: The context to update.
|
||||
|
||||
Returns:
|
||||
UpdateContextResult containing relevant memories
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
"""
|
||||
Query the memory store and return relevant entries.
|
||||
|
||||
Args:
|
||||
query: Query content item
|
||||
cancellation_token: Optional token to cancel operation
|
||||
**kwargs: Additional implementation-specific parameters
|
||||
|
||||
Returns:
|
||||
MemoryQueryResult containing memory entries with relevance scores
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
"""
|
||||
Add a new content to memory.
|
||||
|
||||
Args:
|
||||
content: The memory content to add
|
||||
cancellation_token: Optional token to cancel operation
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> None:
|
||||
"""Clear all entries from memory."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Clean up any resources used by the memory implementation."""
|
||||
...
|
||||
172
agent_dhal/agentdhal_core/memory/_list_memory.py
Normal file
172
agent_dhal/agentdhal_core/memory/_list_memory.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component
|
||||
from ..model_context import ChatCompletionContext
|
||||
from ..models import SystemMessage
|
||||
from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
|
||||
|
||||
|
||||
class ListMemoryConfig(BaseModel):
|
||||
"""Configuration for ListMemory component."""
|
||||
|
||||
name: str | None = None
|
||||
"""Optional identifier for this memory instance."""
|
||||
memory_contents: List[MemoryContent] = Field(default_factory=list)
|
||||
"""List of memory contents stored in this memory instance."""
|
||||
|
||||
|
||||
class ListMemory(Memory, Component[ListMemoryConfig]):
|
||||
"""Simple chronological list-based memory implementation.
|
||||
|
||||
This memory implementation stores contents in a list and retrieves them in
|
||||
chronological order. It has an `update_context` method that updates model contexts
|
||||
by appending all stored memories.
|
||||
|
||||
The memory content can be directly accessed and modified through the content property,
|
||||
allowing external applications to manage memory contents directly.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core.memory import ListMemory, MemoryContent
|
||||
from agentdhal_core.model_context import BufferedChatCompletionContext
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Initialize memory
|
||||
memory = ListMemory(name="chat_history")
|
||||
|
||||
# Add memory content
|
||||
content = MemoryContent(content="User prefers formal language", mime_type="text/plain")
|
||||
await memory.add(content)
|
||||
|
||||
# Directly modify memory contents
|
||||
memory.content = [MemoryContent(content="New preference", mime_type="text/plain")]
|
||||
|
||||
# Create a model context
|
||||
model_context = BufferedChatCompletionContext(buffer_size=10)
|
||||
|
||||
# Update a model context with memory
|
||||
await memory.update_context(model_context)
|
||||
|
||||
# See the updated model context
|
||||
print(await model_context.get_messages())
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Args:
|
||||
name: Optional identifier for this memory instance
|
||||
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
component_provider_override = "agentdhal_core.memory.ListMemory"
|
||||
component_config_schema = ListMemoryConfig
|
||||
|
||||
def __init__(self, name: str | None = None, memory_contents: List[MemoryContent] | None = None) -> None:
|
||||
self._name = name or "default_list_memory"
|
||||
self._contents: List[MemoryContent] = memory_contents if memory_contents is not None else []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the memory instance identifier.
|
||||
|
||||
Returns:
|
||||
str: Memory instance name
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def content(self) -> List[MemoryContent]:
|
||||
"""Get the current memory contents.
|
||||
|
||||
Returns:
|
||||
List[MemoryContent]: List of stored memory contents
|
||||
"""
|
||||
return self._contents
|
||||
|
||||
@content.setter
|
||||
def content(self, value: List[MemoryContent]) -> None:
|
||||
"""Set the memory contents.
|
||||
|
||||
Args:
|
||||
value: New list of memory contents to store
|
||||
"""
|
||||
self._contents = value
|
||||
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
) -> UpdateContextResult:
|
||||
"""Update the model context by appending memory content.
|
||||
|
||||
This method mutates the provided model_context by adding all memories as a
|
||||
SystemMessage.
|
||||
|
||||
Args:
|
||||
model_context: The context to update. Will be mutated if memories exist.
|
||||
|
||||
Returns:
|
||||
UpdateContextResult containing the memories that were added to the context
|
||||
"""
|
||||
|
||||
if not self._contents:
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
|
||||
|
||||
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)]
|
||||
|
||||
if memory_strings:
|
||||
memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n"
|
||||
await model_context.add_message(SystemMessage(content=memory_context))
|
||||
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=self._contents))
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent = "",
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
"""Return all memories without any filtering.
|
||||
|
||||
Args:
|
||||
query: Ignored in this implementation
|
||||
cancellation_token: Optional token to cancel operation
|
||||
**kwargs: Additional parameters (ignored)
|
||||
|
||||
Returns:
|
||||
MemoryQueryResult containing all stored memories
|
||||
"""
|
||||
_ = query, cancellation_token, kwargs
|
||||
return MemoryQueryResult(results=self._contents)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
"""Add new content to memory.
|
||||
|
||||
Args:
|
||||
content: Memory content to store
|
||||
cancellation_token: Optional token to cancel operation
|
||||
"""
|
||||
self._contents.append(content)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all memory content."""
|
||||
self._contents = []
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cleanup resources if needed."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ListMemoryConfig) -> Self:
|
||||
return cls(name=config.name, memory_contents=config.memory_contents)
|
||||
|
||||
def _to_config(self) -> ListMemoryConfig:
|
||||
return ListMemoryConfig(name=self.name, memory_contents=self._contents)
|
||||
16
agent_dhal/agentdhal_core/model_context/__init__.py
Normal file
16
agent_dhal/agentdhal_core/model_context/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from ._buffered_chat_completion_context import BufferedChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
|
||||
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
|
||||
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
|
||||
from ._unbounded_chat_completion_context import (
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatCompletionContext",
|
||||
"ChatCompletionContextState",
|
||||
"UnboundedChatCompletionContext",
|
||||
"BufferedChatCompletionContext",
|
||||
"TokenLimitedChatCompletionContext",
|
||||
"HeadAndTailChatCompletionContext",
|
||||
]
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from ..models import FunctionExecutionResultMessage, LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class BufferedChatCompletionContextConfig(BaseModel):
|
||||
buffer_size: int
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedChatCompletionContextConfig]):
|
||||
"""A buffered chat completion context that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
Args:
|
||||
buffer_size (int): The size of the buffer.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
component_config_schema = BufferedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.BufferedChatCompletionContext"
|
||||
|
||||
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if buffer_size <= 0:
|
||||
raise ValueError("buffer_size must be greater than 0.")
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
messages = self._messages[-self._buffer_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
def _to_config(self) -> BufferedChatCompletionContextConfig:
|
||||
return BufferedChatCompletionContextConfig(
|
||||
buffer_size=self._buffer_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:
|
||||
return cls(**config.model_dump())
|
||||
@@ -0,0 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .._component_config import ComponentBase
|
||||
from ..models import LLMMessage
|
||||
|
||||
|
||||
class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
|
||||
"""An abstract base class for defining the interface of a chat completion context.
|
||||
A chat completion context lets agents store and retrieve LLM messages.
|
||||
It can be implemented with different recall strategies.
|
||||
|
||||
Args:
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
|
||||
Example:
|
||||
|
||||
To create a custom model context that filters out the thought field from AssistantMessage.
|
||||
This is useful for reasoning models like DeepSeek R1, which produces
|
||||
very long thought that is not needed for subsequent completions.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from typing import List
|
||||
|
||||
from agentdhal_core.model_context import UnboundedChatCompletionContext
|
||||
from agentdhal_core.models import AssistantMessage, LLMMessage
|
||||
|
||||
|
||||
class ReasoningModelContext(UnboundedChatCompletionContext):
|
||||
\"\"\"A model context for reasoning models.\"\"\"
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
messages = await super().get_messages()
|
||||
# Filter out thought field from AssistantMessage.
|
||||
messages_out: List[LLMMessage] = []
|
||||
for message in messages:
|
||||
if isinstance(message, AssistantMessage):
|
||||
message.thought = None
|
||||
messages_out.append(message)
|
||||
return messages_out
|
||||
|
||||
"""
|
||||
|
||||
component_type = "chat_completion_context"
|
||||
|
||||
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = []
|
||||
if initial_messages is not None:
|
||||
self._messages.extend(initial_messages)
|
||||
self._initial_messages = initial_messages
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the context."""
|
||||
self._messages.append(message)
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages(self) -> List[LLMMessage]: ...
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the context."""
|
||||
self._messages = []
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return ChatCompletionContextState(messages=self._messages).model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = ChatCompletionContextState.model_validate(state).messages
|
||||
|
||||
|
||||
class ChatCompletionContextState(BaseModel):
|
||||
messages: List[LLMMessage] = Field(default_factory=list)
|
||||
@@ -0,0 +1,76 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from .._types import FunctionCall
|
||||
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class HeadAndTailChatCompletionContextConfig(BaseModel):
|
||||
head_size: int
|
||||
tail_size: int
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndTailChatCompletionContextConfig]):
|
||||
"""A chat completion context that keeps a view of the first n and last m messages,
|
||||
where n is the head size and m is the tail size. The head and tail sizes
|
||||
are set at initialization.
|
||||
|
||||
Args:
|
||||
head_size (int): The size of the head.
|
||||
tail_size (int): The size of the tail.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
component_config_schema = HeadAndTailChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.HeadAndTailChatCompletionContext"
|
||||
|
||||
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if head_size <= 0:
|
||||
raise ValueError("head_size must be greater than 0.")
|
||||
if tail_size <= 0:
|
||||
raise ValueError("tail_size must be greater than 0.")
|
||||
self._head_size = head_size
|
||||
self._tail_size = tail_size
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
||||
head_messages = self._messages[: self._head_size]
|
||||
# Handle the last message is a function call message.
|
||||
if (
|
||||
head_messages
|
||||
and isinstance(head_messages[-1], AssistantMessage)
|
||||
and isinstance(head_messages[-1].content, list)
|
||||
and all(isinstance(item, FunctionCall) for item in head_messages[-1].content)
|
||||
):
|
||||
# Remove the last message from the head.
|
||||
head_messages = head_messages[:-1]
|
||||
|
||||
tail_messages = self._messages[-self._tail_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the tail.
|
||||
tail_messages = tail_messages[1:]
|
||||
|
||||
num_skipped = len(self._messages) - self._head_size - self._tail_size
|
||||
if num_skipped <= 0:
|
||||
# If there are not enough messages to fill the head and tail,
|
||||
# return all messages.
|
||||
return self._messages
|
||||
|
||||
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
return head_messages + placeholder_messages + tail_messages
|
||||
|
||||
def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
|
||||
return HeadAndTailChatCompletionContextConfig(
|
||||
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HeadAndTailChatCompletionContextConfig) -> Self:
|
||||
return cls(head_size=config.head_size, tail_size=config.tail_size, initial_messages=config.initial_messages)
|
||||
@@ -0,0 +1,94 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
|
||||
from ..tools import ToolSchema
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContextConfig(BaseModel):
|
||||
model_client: ComponentModel
|
||||
token_limit: int | None = None
|
||||
tool_schema: List[ToolSchema] | None = None
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
|
||||
"""(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
|
||||
|
||||
.. note::
|
||||
|
||||
Added in v0.4.10. This is an experimental component and may change in the future.
|
||||
|
||||
Args:
|
||||
model_client (ChatCompletionClient): The model client to use for token counting.
|
||||
The model client must implement the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens`
|
||||
and :meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` methods.
|
||||
token_limit (int | None): The maximum number of tokens to keep in the context
|
||||
using the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens` method.
|
||||
If None, the context will be limited by the model client using the
|
||||
:meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` method.
|
||||
tools (List[ToolSchema] | None): A list of tool schema to use in the context.
|
||||
initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = TokenLimitedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.TokenLimitedChatCompletionContext"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
token_limit: int | None = None,
|
||||
tool_schema: List[ToolSchema] | None = None,
|
||||
initial_messages: List[LLMMessage] | None = None,
|
||||
) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if token_limit is not None and token_limit <= 0:
|
||||
raise ValueError("token_limit must be greater than 0.")
|
||||
self._token_limit = token_limit
|
||||
self._model_client = model_client
|
||||
self._tool_schema = tool_schema or []
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `token_limit` tokens in recent messages. If the token limit is not
|
||||
provided, then return as many messages as the remaining token allowed by the model client."""
|
||||
messages = list(self._messages)
|
||||
if self._token_limit is None:
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
while remaining_tokens < 0 and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
else:
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
while token_count > self._token_limit and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Handle the first message is a function call result message.
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
|
||||
return TokenLimitedChatCompletionContextConfig(
|
||||
model_client=self._model_client.dump_component(),
|
||||
token_limit=self._token_limit,
|
||||
tool_schema=self._tool_schema,
|
||||
initial_messages=self._initial_messages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
|
||||
return cls(
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
token_limit=config.token_limit,
|
||||
tool_schema=config.tool_schema,
|
||||
initial_messages=config.initial_messages,
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from ..models import LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class UnboundedChatCompletionContextConfig(BaseModel):
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
|
||||
"""An unbounded chat completion context that keeps a view of the all the messages."""
|
||||
|
||||
component_config_schema = UnboundedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.UnboundedChatCompletionContext"
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
return self._messages
|
||||
|
||||
def _to_config(self) -> UnboundedChatCompletionContextConfig:
|
||||
return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
|
||||
return cls(initial_messages=config.initial_messages)
|
||||
0
agent_dhal/agentdhal_core/py.typed
Normal file
0
agent_dhal/agentdhal_core/py.typed
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user