first commit
This commit is contained in:
25
agent_dhal/agentdhal_agentchat/agents/__init__.py
Normal file
25
agent_dhal/agentdhal_agentchat/agents/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
This module initializes various pre-defined agents provided by the package.
|
||||
BaseChatAgent is the base class for all agents in AgentChat.
|
||||
"""
|
||||
|
||||
from ._assistant_agent import AssistantAgent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import ApprovalFuncType, ApprovalRequest, ApprovalResponse, CodeExecutorAgent
|
||||
from ._message_filter_agent import MessageFilterAgent, MessageFilterConfig, PerSourceFilter
|
||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||
from ._user_proxy_agent import UserProxyAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"AssistantAgent",
|
||||
"CodeExecutorAgent",
|
||||
"SocietyOfMindAgent",
|
||||
"UserProxyAgent",
|
||||
"MessageFilterAgent",
|
||||
"MessageFilterConfig",
|
||||
"PerSourceFilter",
|
||||
"ApprovalRequest",
|
||||
"ApprovalResponse",
|
||||
"ApprovalFuncType",
|
||||
]
|
||||
1699
agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py
Normal file
1699
agent_dhal/agentdhal_agentchat/agents/_assistant_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
245
agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py
Normal file
245
agent_dhal/agentdhal_agentchat/agents/_base_chat_agent.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, ComponentBase, trace_create_agent_span, trace_invoke_agent_span
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
from ..state import BaseState
|
||||
|
||||
|
||||
class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
|
||||
"""Base class for a chat agent.
|
||||
|
||||
This abstract class provides a base implementation for a :class:`ChatAgent`.
|
||||
To create a new chat agent, subclass this class and implement the
|
||||
:meth:`on_messages`, :meth:`on_reset`, and :attr:`produced_message_types`.
|
||||
If streaming is required, also implement the :meth:`on_messages_stream` method.
|
||||
|
||||
An agent is considered stateful and maintains its state between calls to
|
||||
the :meth:`on_messages` or :meth:`on_messages_stream` methods.
|
||||
The agent should store its state in the
|
||||
agent instance. The agent should also implement the :meth:`on_reset` method
|
||||
to reset the agent to its initialization state.
|
||||
|
||||
.. note::
|
||||
|
||||
The caller should only pass the new messages to the agent on each call
|
||||
to the :meth:`on_messages` or :meth:`on_messages_stream` method.
|
||||
Do not pass the entire conversation history to the agent on each call.
|
||||
This design principle must be followed when creating a new agent.
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
"""Initialize the agent with a name and description."""
|
||||
with trace_create_agent_span(
|
||||
agent_name=name,
|
||||
agent_description=description,
|
||||
):
|
||||
self._name = name
|
||||
if self._name.isidentifier() is False:
|
||||
raise ValueError("The agent name must be a valid Python identifier.")
|
||||
self._description = description
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the agent. This is used by team to uniquely identify
|
||||
the agent. It should be unique within the team."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""The description of the agent. This is used by team to
|
||||
make decisions about which agents to use. The description should
|
||||
describe the agent's capabilities and how to interact with it."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response.
|
||||
|
||||
.. note::
|
||||
|
||||
Agents are stateful and the messages passed to this method should
|
||||
be the new messages since the last call to this method. The agent
|
||||
should maintain its state between calls to this method. For example,
|
||||
if the agent needs to remember the previous messages to respond to
|
||||
the current message, it should store the previous messages in the
|
||||
agent state.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of messages and
|
||||
and the final item is the response. The base implementation in
|
||||
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
|
||||
the messages in the response.
|
||||
|
||||
.. note::
|
||||
|
||||
Agents are stateful and the messages passed to this method should
|
||||
be the new messages since the last call to this method. The agent
|
||||
should maintain its state between calls to this method. For example,
|
||||
if the agent needs to remember the previous messages to respond to
|
||||
the current message, it should store the previous messages in the
|
||||
agent state.
|
||||
|
||||
"""
|
||||
response = await self.on_messages(messages, cancellation_token)
|
||||
for inner_message in response.inner_messages or []:
|
||||
yield inner_message
|
||||
yield response
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self.name,
|
||||
agent_description=self.description,
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
if output_task_messages:
|
||||
output_messages.append(task)
|
||||
else:
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
# Task is a sequence of messages.
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
response = await self.on_messages(input_messages, cancellation_token)
|
||||
if response.inner_messages is not None:
|
||||
output_messages += response.inner_messages
|
||||
output_messages.append(response.chat_message)
|
||||
return TaskResult(messages=output_messages)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
and the final task result as the last item in the stream.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
"""
|
||||
with trace_invoke_agent_span(
|
||||
agent_name=self.name,
|
||||
agent_description=self.description,
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[BaseChatMessage] = []
|
||||
output_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
if output_task_messages:
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
else:
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
if output_task_messages:
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
async for message in self.on_messages_stream(input_messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
yield message.chat_message
|
||||
output_messages.append(message.chat_message)
|
||||
yield TaskResult(messages=output_messages)
|
||||
else:
|
||||
yield message
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
output_messages.append(message)
|
||||
|
||||
@abstractmethod
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Resets the agent to its initialization state."""
|
||||
...
|
||||
|
||||
async def on_pause(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is paused while running in its :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` method. This is a no-op by default in the
|
||||
:class:`BaseChatAgent` class. Subclasses can override this method to
|
||||
implement custom pause behavior."""
|
||||
pass
|
||||
|
||||
async def on_resume(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is resumed from a pause while running in
|
||||
its :meth:`on_messages` or :meth:`on_messages_stream` method.
|
||||
This is a no-op by default in the :class:`BaseChatAgent` class.
|
||||
Subclasses can override this method to implement custom resume behavior."""
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Export state. Default implementation for stateless agents."""
|
||||
return BaseState().model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state. Default implementation for stateless agents."""
|
||||
BaseState.model_validate(state)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release any resources held by the agent. This is a no-op by default in the
|
||||
:class:`BaseChatAgent` class. Subclasses can override this method to
|
||||
implement custom close behavior."""
|
||||
pass
|
||||
881
agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py
Normal file
881
agent_dhal/agentdhal_agentchat/agents/_code_executor_agent.py
Normal file
@@ -0,0 +1,881 @@
|
||||
import logging
|
||||
import re
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
CodeExecutionEvent,
|
||||
CodeGenerationEvent,
|
||||
HandoffMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
ThoughtEvent,
|
||||
)
|
||||
from ..utils import remove_images
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class CodeExecutorAgentConfig(BaseModel):
|
||||
"""Configuration for CodeExecutorAgent"""
|
||||
|
||||
name: str
|
||||
code_executor: ComponentModel
|
||||
model_client: ComponentModel | None = None
|
||||
description: str | None = None
|
||||
sources: List[str] | None = None
|
||||
system_message: str | None = None
|
||||
model_client_stream: bool = False
|
||||
model_context: ComponentModel | None = None
|
||||
supported_languages: List[str] | None = None
|
||||
|
||||
|
||||
class RetryDecision(BaseModel):
|
||||
reason: str
|
||||
retry: bool
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request for approval of code execution."""
|
||||
|
||||
code: str
|
||||
context: List[LLMMessage]
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
"""Response to approval request."""
|
||||
|
||||
approved: bool
|
||||
reason: str
|
||||
|
||||
|
||||
# Type aliases for approval functions
|
||||
SyncApprovalFunc = Callable[[ApprovalRequest], ApprovalResponse]
|
||||
AsyncApprovalFunc = Callable[[ApprovalRequest], Awaitable[ApprovalResponse]]
|
||||
ApprovalFuncType = Union[SyncApprovalFunc, AsyncApprovalFunc]
|
||||
|
||||
|
||||
class CodeExecutorAgent(BaseChatAgent, Component[CodeExecutorAgentConfig]):
|
||||
"""(Experimental) An agent that generates and executes code snippets based on user instructions.
|
||||
|
||||
.. note::
|
||||
|
||||
This agent is experimental and may change in future releases.
|
||||
|
||||
It is typically used within a team with another agent that generates code snippets
|
||||
to be executed or alone with `model_client` provided so that it can generate code
|
||||
based on user query, execute it and reflect on the code result.
|
||||
|
||||
When used with `model_client`, it will generate code snippets using the model
|
||||
and execute them using the provided `code_executor`. The model will also reflect on the
|
||||
code execution results. The agent will yield the final reflection result from the model
|
||||
as the final response.
|
||||
|
||||
When used without `model_client`, it will only execute code blocks found in
|
||||
:class:`~agentdhal_agentchat.messages.TextMessage` messages and returns the output
|
||||
of the code execution.
|
||||
|
||||
.. note::
|
||||
|
||||
Using :class:`~agentdhal_agentchat.agents.AssistantAgent` with
|
||||
:class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`
|
||||
is an alternative to this agent. However, the model for that agent will
|
||||
have to generate properly escaped code string as a parameter to the tool.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
code_executor (CodeExecutor): The code executor responsible for executing code received in messages
|
||||
(:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` recommended. See example below)
|
||||
model_client (ChatCompletionClient, optional): The model client to use for inference and generating code.
|
||||
If not provided, the agent will only execute code blocks found in input messages.
|
||||
Currently, the model must support structured output mode, which is required for
|
||||
the automatic retry mechanism to work.
|
||||
model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
|
||||
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will
|
||||
also yield :class:`~agentdhal_agentchat.messages.ModelClientStreamingChunkEvent`
|
||||
messages as the model client produces chunks of response. Defaults to `False`.
|
||||
description (str, optional): The description of the agent. If not provided,
|
||||
:class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION` will be used.
|
||||
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
|
||||
Defaults to :class:`~agentdhal_agentchat.agents.CodeExecutorAgent.DEFAULT_SYSTEM_MESSAGE`. This is only used if `model_client` is provided.
|
||||
sources (Sequence[str], optional): Check only messages from the specified agents for the code to execute.
|
||||
This is useful when the agent is part of a group chat and you want to limit the code execution to messages from specific agents.
|
||||
If not provided, all messages will be checked for code blocks.
|
||||
This is only used if `model_client` is not provided.
|
||||
max_retries_on_error (int, optional): The maximum number of retries on error. If the code execution fails, the agent will retry up to this number of times.
|
||||
If the code execution fails after this number of retries, the agent will yield a reflection result.
|
||||
supported_languages (List[str], optional): List of programming languages that will be parsed and executed from agent response;
|
||||
others will be ignored. Defaults to DEFAULT_SUPPORTED_LANGUAGES.
|
||||
approval_func (Optional[Union[Callable[[ApprovalRequest], ApprovalResponse], Callable[[ApprovalRequest], Awaitable[ApprovalResponse]]]], optional): A function that is called before each code execution to get approval.
|
||||
The function takes an ApprovalRequest containing the code to be executed and the current context, and returns an ApprovalResponse.
|
||||
The function can be either synchronous or asynchronous. If None (default), all code executions are automatically approved.
|
||||
If set, the agent cannot be serialized using :meth:`~agentdhal_agentchat.agents.CodeExecutorAgent.dump_component`.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
It is recommended that the `CodeExecutorAgent` agent uses a Docker container to execute code. This ensures that model-generated code is executed in an isolated environment. To use Docker, your environment must have Docker installed and running.
|
||||
Follow the installation instructions for `Docker <https://docs.docker.com/get-docker/>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
The code executor only processes code that is properly formatted in markdown code blocks using triple backticks.
|
||||
For example:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
```python
|
||||
print("Hello World")
|
||||
```
|
||||
|
||||
# or
|
||||
|
||||
```sh
|
||||
echo "Hello World"
|
||||
```
|
||||
|
||||
In this example, we show how to set up a `CodeExecutorAgent` agent that uses the
|
||||
:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
|
||||
to execute code snippets in a Docker container. The `work_dir` parameter indicates
|
||||
where all executed files are first saved locally before being executed in the Docker container.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
def simple_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Simple approval function that requests user input for code execution approval.\"\"\"
|
||||
print("Code execution approval requested:")
|
||||
print("=" * 50)
|
||||
print(request.code)
|
||||
print("=" * 50)
|
||||
|
||||
while True:
|
||||
user_input = input("Do you want to execute this code? (y/n): ").strip().lower()
|
||||
if user_input in ['y', 'yes']:
|
||||
return ApprovalResponse(approved=True, reason='Approved by user')
|
||||
elif user_input in ['n', 'no']:
|
||||
return ApprovalResponse(approved=False, reason='Denied by user')
|
||||
else:
|
||||
print("Please enter 'y' for yes or 'n' for no.")
|
||||
|
||||
|
||||
async def run_code_executor_agent() -> None:
|
||||
# Create a code executor agent that uses a Docker container to execute code.
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
await code_executor.start()
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor",
|
||||
code_executor=code_executor,
|
||||
approval_func=simple_approval_func
|
||||
)
|
||||
|
||||
# Run the agent with a given code snippet.
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```python
|
||||
print('Hello world')
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
# Stop the code executor.
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(run_code_executor_agent())
|
||||
|
||||
In this example, we show how to set up a `CodeExecutorAgent` agent that uses the
|
||||
:py:class:`~docker.types.DeviceRequest` to expose a GPU to the container for cuda-accelerated code execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_core import CancellationToken
|
||||
from docker.types import DeviceRequest
|
||||
|
||||
|
||||
async def run_code_executor_agent() -> None:
|
||||
# Create a code executor agent that uses a Docker container to execute code.
|
||||
code_executor = DockerCommandLineCodeExecutor(
|
||||
work_dir="coding", device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])]
|
||||
)
|
||||
await code_executor.start()
|
||||
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=code_executor)
|
||||
|
||||
# Display the GPU information
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```sh
|
||||
nvidia-smi
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
# Stop the code executor.
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(run_code_executor_agent())
|
||||
|
||||
In the following example, we show how to setup `CodeExecutorAgent` without `model_client` parameter for executing code blocks generated by other agents in a group chat using :py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent, CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
termination_condition = MaxMessageTermination(3)
|
||||
|
||||
|
||||
def group_chat_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Approval function for group chat that allows basic Python operations.\"\"\"
|
||||
# Allow common safe operations
|
||||
safe_operations = ["print(", "import ", "def ", "class ", "if ", "for ", "while "]
|
||||
if any(op in request.code for op in safe_operations):
|
||||
return ApprovalResponse(approved=True, reason='Safe Python operation')
|
||||
|
||||
# Deny file system operations in group chat
|
||||
dangerous_operations = ["open(", "file(", "os.", "subprocess", "eval(", "exec("]
|
||||
if any(op in request.code for op in dangerous_operations):
|
||||
return ApprovalResponse(approved=False, reason='File system or dangerous operation not allowed')
|
||||
|
||||
return ApprovalResponse(approved=True, reason='Operation approved')
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
# define the Docker CLI Code Executor
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
|
||||
# start the execution container
|
||||
await code_executor.start()
|
||||
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor_agent",
|
||||
code_executor=code_executor,
|
||||
approval_func=group_chat_approval_func
|
||||
)
|
||||
coder_agent = AssistantAgent("coder_agent", model_client=model_client)
|
||||
|
||||
groupchat = RoundRobinGroupChat(
|
||||
participants=[coder_agent, code_executor_agent], termination_condition=termination_condition
|
||||
)
|
||||
|
||||
task = "Write python code to print Hello World!"
|
||||
await Console(groupchat.run_stream(task=task))
|
||||
|
||||
# stop the execution container
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- user ----------
|
||||
Write python code to print Hello World!
|
||||
---------- coder_agent ----------
|
||||
Certainly! Here's a simple Python code to print "Hello World!":
|
||||
|
||||
```python
|
||||
print("Hello World!")
|
||||
```
|
||||
|
||||
You can run this code in any Python environment to display the message.
|
||||
---------- code_executor_agent ----------
|
||||
Hello World!
|
||||
|
||||
In the following example, we show how to setup `CodeExecutorAgent` with `model_client`
|
||||
that can generate its own code without the help of any other agent and executing it in
|
||||
:py:class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`.
|
||||
It also demonstrates using a model-based approval function that reviews the code for safety before execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_core.models import SystemMessage, UserMessage
|
||||
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent, ApprovalRequest, ApprovalResponse
|
||||
from agentdhal_agentchat.conditions import TextMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
termination_condition = TextMessageTermination("code_executor_agent")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def model_client_approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
instruction = "Approve or reject the code in the last message based on whether it is dangerous or not. Use the following JSON format for your response: {approved: true/false, reason: 'your reason here'}"
|
||||
response = await model_client.create(
|
||||
messages=[SystemMessage(content=instruction)]
|
||||
+ request.context
|
||||
+ [UserMessage(content=request.code, source="user")],
|
||||
json_output=ApprovalResponse,
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
return ApprovalResponse.model_validate_json(response.content)
|
||||
|
||||
# define the Docker CLI Code Executor
|
||||
code_executor = DockerCommandLineCodeExecutor(work_dir="coding")
|
||||
|
||||
# start the execution container
|
||||
await code_executor.start()
|
||||
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor_agent",
|
||||
code_executor=code_executor,
|
||||
model_client=model_client,
|
||||
approval_func=model_client_approval_func,
|
||||
)
|
||||
|
||||
task = "Write python code to print Hello World!"
|
||||
await Console(code_executor_agent.run_stream(task=task))
|
||||
|
||||
# stop the execution container
|
||||
await code_executor.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- user ----------
|
||||
Write python code to print Hello World!
|
||||
---------- code_executor_agent ----------
|
||||
Certainly! Here is a simple Python code to print "Hello World!" to the console:
|
||||
|
||||
```python
|
||||
print("Hello World!")
|
||||
```
|
||||
|
||||
Let's execute it to confirm the output.
|
||||
---------- code_executor_agent ----------
|
||||
Hello World!
|
||||
|
||||
---------- code_executor_agent ----------
|
||||
The code has been executed successfully, and it printed "Hello World!" as expected. If you have any more requests or questions, feel free to ask!
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_TERMINAL_DESCRIPTION = "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks)."
|
||||
DEFAULT_AGENT_DESCRIPTION = "A Code Execution Agent that generates and executes Python and shell scripts based on user instructions. It ensures correctness, efficiency, and minimal errors while gracefully handling edge cases."
|
||||
DEFAULT_SYSTEM_MESSAGE = "You are a Code Execution Agent. Your role is to generate and execute Python code and shell scripts based on user instructions, ensuring correctness, efficiency, and minimal errors. Handle edge cases gracefully. Python code should be provided in ```python code blocks, and sh shell scripts should be provided in ```sh code blocks for execution."
|
||||
NO_CODE_BLOCKS_FOUND_MESSAGE = "No code blocks found in the thread. Please provide at least one markdown-encoded code block to execute (i.e., quoting code in ```python or ```sh code blocks)."
|
||||
DEFAULT_SUPPORTED_LANGUAGES = ["python", "sh"]
|
||||
|
||||
component_config_schema = CodeExecutorAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.CodeExecutorAgent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
code_executor: CodeExecutor,
|
||||
*,
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
model_client_stream: bool = False,
|
||||
max_retries_on_error: int = 0,
|
||||
description: str | None = None,
|
||||
system_message: str | None = DEFAULT_SYSTEM_MESSAGE,
|
||||
sources: Sequence[str] | None = None,
|
||||
supported_languages: List[str] | None = None,
|
||||
approval_func: Optional[ApprovalFuncType] = None,
|
||||
) -> None:
|
||||
if description is None:
|
||||
if model_client is None:
|
||||
description = CodeExecutorAgent.DEFAULT_TERMINAL_DESCRIPTION
|
||||
else:
|
||||
description = CodeExecutorAgent.DEFAULT_AGENT_DESCRIPTION
|
||||
|
||||
super().__init__(name=name, description=description)
|
||||
self._code_executor = code_executor
|
||||
self._sources = sources
|
||||
self._model_client_stream = model_client_stream
|
||||
self._max_retries_on_error = max_retries_on_error
|
||||
self._approval_func = approval_func
|
||||
self._approval_func_is_async = approval_func is not None and iscoroutinefunction(approval_func)
|
||||
|
||||
if supported_languages is not None:
|
||||
self._supported_languages = supported_languages
|
||||
else:
|
||||
self._supported_languages = CodeExecutorAgent.DEFAULT_SUPPORTED_LANGUAGES
|
||||
|
||||
self._supported_languages_regex = "|".join(re.escape(lang) for lang in self._supported_languages)
|
||||
|
||||
self._model_client = None
|
||||
if model_client is not None:
|
||||
self._model_client = model_client
|
||||
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
|
||||
self._system_messaages: List[SystemMessage] = []
|
||||
if system_message is None:
|
||||
self._system_messages = []
|
||||
else:
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
|
||||
if self._max_retries_on_error > 0:
|
||||
if not self._model_client or not self._model_client.model_info:
|
||||
raise ValueError("model_client.model_info must be provided when max_retries_on_error > 0")
|
||||
if not self._model_client.model_info["structured_output"]:
|
||||
raise ValueError("Specified model_client doesn't support structured output mode.")
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the code executor agent produces."""
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def model_context(self) -> ChatCompletionContext:
|
||||
"""
|
||||
The model context in use by the agent.
|
||||
"""
|
||||
return self._model_context
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""
|
||||
Process the incoming messages with the assistant agent and yield events/responses as they happen.
|
||||
"""
|
||||
|
||||
# Gather all relevant state here
|
||||
agent_name = self.name
|
||||
model_context = self._model_context
|
||||
system_messages = self._system_messages
|
||||
model_client = self._model_client
|
||||
model_client_stream = self._model_client_stream
|
||||
max_retries_on_error = self._max_retries_on_error
|
||||
|
||||
execution_result: CodeResult | None = None
|
||||
if model_client is None: # default behaviour for backward compatibility
|
||||
# execute generated code if present
|
||||
code_blocks: List[CodeBlock] = await self.extract_code_blocks_from_messages(messages)
|
||||
if not code_blocks:
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=self.NO_CODE_BLOCKS_FOUND_MESSAGE,
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
return
|
||||
execution_result = await self.execute_code_block(code_blocks, cancellation_token)
|
||||
yield Response(chat_message=TextMessage(content=execution_result.output, source=self.name))
|
||||
return
|
||||
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
|
||||
for nth_try in range(max_retries_on_error + 1): # Do one default generation, execution and inference loop
|
||||
# Step 1: Add new user/handoff messages to the model context
|
||||
await self._add_messages_to_context(
|
||||
model_context=model_context,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Step 2: Run inference with the model context
|
||||
model_result = None
|
||||
async for inference_output in self._call_llm(
|
||||
model_client=model_client,
|
||||
model_client_stream=model_client_stream,
|
||||
system_messages=system_messages,
|
||||
model_context=model_context,
|
||||
agent_name=agent_name,
|
||||
cancellation_token=cancellation_token,
|
||||
):
|
||||
if isinstance(inference_output, CreateResult):
|
||||
model_result = inference_output
|
||||
else:
|
||||
# Streaming chunk event
|
||||
yield inference_output
|
||||
|
||||
assert model_result is not None, "No model result was produced."
|
||||
|
||||
# Step 3: [NEW] If the model produced a hidden "thought," yield it as an event
|
||||
if model_result.thought:
|
||||
thought_event = ThoughtEvent(content=model_result.thought, source=agent_name)
|
||||
yield thought_event
|
||||
inner_messages.append(thought_event)
|
||||
|
||||
# Step 4: Add the assistant message to the model context (including thought if present)
|
||||
await model_context.add_message(
|
||||
AssistantMessage(
|
||||
content=model_result.content,
|
||||
source=agent_name,
|
||||
thought=getattr(model_result, "thought", None),
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Extract the code blocks from inferred text
|
||||
assert isinstance(model_result.content, str), "Expected inferred model_result.content to be of type str."
|
||||
code_blocks = self._extract_markdown_code_blocks(str(model_result.content))
|
||||
|
||||
# Step 6: Exit the loop if no code blocks found
|
||||
if not code_blocks:
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=str(model_result.content),
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Step 7: Yield a CodeGenerationEvent
|
||||
inferred_text_message: CodeGenerationEvent = CodeGenerationEvent(
|
||||
retry_attempt=nth_try,
|
||||
content=model_result.content,
|
||||
code_blocks=code_blocks,
|
||||
source=agent_name,
|
||||
)
|
||||
|
||||
yield inferred_text_message
|
||||
|
||||
# Step 8: Execute the extracted code blocks
|
||||
execution_result = await self.execute_code_block(inferred_text_message.code_blocks, cancellation_token)
|
||||
|
||||
# Step 9: Update model context with the code execution result
|
||||
await model_context.add_message(
|
||||
UserMessage(
|
||||
content=execution_result.output,
|
||||
source=agent_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Step 10: Yield a CodeExecutionEvent
|
||||
yield CodeExecutionEvent(retry_attempt=nth_try, result=execution_result, source=self.name)
|
||||
|
||||
# If execution was successful or last retry, then exit
|
||||
if execution_result.exit_code == 0 or nth_try == max_retries_on_error:
|
||||
break
|
||||
|
||||
# Step 11: If exit code is non-zero and retries are available then
|
||||
# make an inference asking if we should retry or not
|
||||
chat_context = await model_context.get_messages()
|
||||
|
||||
retry_prompt = (
|
||||
f"The most recent code execution resulted in an error:\n{execution_result.output}\n\n"
|
||||
"Should we attempt to resolve it? Please respond with:\n"
|
||||
"- A boolean value for 'retry' indicating whether it should be retried.\n"
|
||||
"- A detailed explanation in 'reason' that identifies the issue, justifies your decision to retry or not, and outlines how you would resolve the error if a retry is attempted."
|
||||
)
|
||||
|
||||
chat_context = chat_context + [
|
||||
UserMessage(
|
||||
content=retry_prompt,
|
||||
source=agent_name,
|
||||
)
|
||||
]
|
||||
|
||||
response = await model_client.create(messages=chat_context, json_output=RetryDecision)
|
||||
|
||||
assert isinstance(
|
||||
response.content, str
|
||||
), "Expected structured response for retry decision to be of type str."
|
||||
should_retry_generation = RetryDecision.model_validate_json(str(response.content))
|
||||
|
||||
# Exit if no-retry is needed
|
||||
if not should_retry_generation.retry:
|
||||
break
|
||||
|
||||
yield CodeGenerationEvent(
|
||||
retry_attempt=nth_try,
|
||||
content=f"Attempt number: {nth_try + 1}\nProposed correction: {should_retry_generation.reason}",
|
||||
code_blocks=[],
|
||||
source=agent_name,
|
||||
)
|
||||
|
||||
# Always reflect on the execution result
|
||||
async for reflection_response in CodeExecutorAgent._reflect_on_code_block_results_flow(
|
||||
system_messages=system_messages,
|
||||
model_client=model_client,
|
||||
model_client_stream=model_client_stream,
|
||||
model_context=model_context,
|
||||
agent_name=agent_name,
|
||||
inner_messages=inner_messages,
|
||||
):
|
||||
yield reflection_response # Last reflection_response is of type Response so it will finish the routine
|
||||
|
||||
async def extract_code_blocks_from_messages(self, messages: Sequence[BaseChatMessage]) -> List[CodeBlock]:
|
||||
# Extract code blocks from the messages.
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for msg in messages:
|
||||
if self._sources is None or msg.source in self._sources:
|
||||
if isinstance(msg, TextMessage):
|
||||
code_blocks.extend(self._extract_markdown_code_blocks(msg.content))
|
||||
# TODO: handle other message types if needed
|
||||
return code_blocks
|
||||
|
||||
async def execute_code_block(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CodeResult:
|
||||
# Check for approval before executing code blocks
|
||||
if self._approval_func is not None:
|
||||
# Combine all code blocks into a single string for approval
|
||||
combined_code = "\n\n".join([f"```{block.language}\n{block.code}\n```" for block in code_blocks])
|
||||
|
||||
# Get the current context from model_context
|
||||
context_messages = await self._model_context.get_messages()
|
||||
|
||||
# Create approval request
|
||||
approval_request = ApprovalRequest(code=combined_code, context=context_messages)
|
||||
|
||||
# Get approval (handle both sync and async functions)
|
||||
if self._approval_func_is_async:
|
||||
# Cast to AsyncApprovalFunc for proper typing
|
||||
async_func = cast(AsyncApprovalFunc, self._approval_func)
|
||||
approval_response = await async_func(approval_request)
|
||||
else:
|
||||
# Cast to SyncApprovalFunc for proper typing
|
||||
sync_func = cast(SyncApprovalFunc, self._approval_func)
|
||||
approval_response = sync_func(approval_request)
|
||||
|
||||
# If not approved, return error result
|
||||
if not approval_response.approved:
|
||||
return CodeResult(
|
||||
exit_code=1, output=f"Code execution was not approved. Reason: {approval_response.reason}"
|
||||
)
|
||||
|
||||
# Execute the code blocks.
|
||||
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||
|
||||
if result.output.strip() == "":
|
||||
# No output
|
||||
result.output = f"The script ran but produced no output to console. The POSIX exit code was: {result.exit_code}. If you were expecting output, consider revising the script to ensure content is printed to stdout."
|
||||
elif result.exit_code != 0:
|
||||
# Error
|
||||
result.output = f"The script ran, then exited with an error (POSIX exit code: {result.exit_code})\nIts output was:\n{result.output}"
|
||||
|
||||
return result
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Its a no-op as the code executor agent has no mutable state."""
|
||||
pass
|
||||
|
||||
def _extract_markdown_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
|
||||
pattern = re.compile(rf"```(?:\s*({self._supported_languages_regex}))\n([\s\S]*?)```", re.IGNORECASE)
|
||||
matches = pattern.findall(markdown_text)
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for match in matches:
|
||||
language = match[0].strip() if match[0] else ""
|
||||
code_content = match[1]
|
||||
code_blocks.append(CodeBlock(code=code_content, language=language))
|
||||
return code_blocks
|
||||
|
||||
def _to_config(self) -> CodeExecutorAgentConfig:
|
||||
if self._approval_func is not None:
|
||||
raise ValueError(
|
||||
"Cannot serialize CodeExecutorAgent with approval_func set. The approval function is not serializable."
|
||||
)
|
||||
|
||||
return CodeExecutorAgentConfig(
|
||||
name=self.name,
|
||||
model_client=(self._model_client.dump_component() if self._model_client is not None else None),
|
||||
code_executor=self._code_executor.dump_component(),
|
||||
description=self.description,
|
||||
sources=list(self._sources) if self._sources is not None else None,
|
||||
system_message=(
|
||||
self._system_messages[0].content
|
||||
if self._system_messages and isinstance(self._system_messages[0].content, str)
|
||||
else None
|
||||
),
|
||||
model_client_stream=self._model_client_stream,
|
||||
model_context=self._model_context.dump_component(),
|
||||
supported_languages=self._supported_languages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: CodeExecutorAgentConfig) -> Self:
|
||||
return cls(
|
||||
name=config.name,
|
||||
model_client=(
|
||||
ChatCompletionClient.load_component(config.model_client) if config.model_client is not None else None
|
||||
),
|
||||
code_executor=CodeExecutor.load_component(config.code_executor),
|
||||
description=config.description,
|
||||
sources=config.sources,
|
||||
system_message=config.system_message,
|
||||
model_client_stream=config.model_client_stream,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
supported_languages=config.supported_languages,
|
||||
approval_func=None, # approval_func cannot be serialized, so it's always None when loading from config
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_compatible_context(model_client: ChatCompletionClient, messages: List[LLMMessage]) -> Sequence[LLMMessage]:
|
||||
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
|
||||
if model_client.model_info["vision"]:
|
||||
return messages
|
||||
else:
|
||||
return remove_images(messages)
|
||||
|
||||
@classmethod
|
||||
async def _call_llm(
|
||||
cls,
|
||||
model_client: ChatCompletionClient,
|
||||
model_client_stream: bool,
|
||||
system_messages: List[SystemMessage],
|
||||
model_context: ChatCompletionContext,
|
||||
agent_name: str,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
|
||||
"""
|
||||
Perform a model inference and yield either streaming chunk events or the final CreateResult.
|
||||
"""
|
||||
all_messages = await model_context.get_messages()
|
||||
llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)
|
||||
|
||||
if model_client_stream:
|
||||
model_result: Optional[CreateResult] = None
|
||||
async for chunk in model_client.create_stream(
|
||||
llm_messages, tools=[], cancellation_token=cancellation_token
|
||||
):
|
||||
if isinstance(chunk, CreateResult):
|
||||
model_result = chunk
|
||||
elif isinstance(chunk, str):
|
||||
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
|
||||
if model_result is None:
|
||||
raise RuntimeError("No final model result in streaming mode.")
|
||||
yield model_result
|
||||
else:
|
||||
model_result = await model_client.create(llm_messages, tools=[], cancellation_token=cancellation_token)
|
||||
yield model_result
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
@classmethod
|
||||
async def _reflect_on_code_block_results_flow(
|
||||
cls,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
model_client_stream: bool,
|
||||
model_context: ChatCompletionContext,
|
||||
agent_name: str,
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage],
|
||||
) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
|
||||
"""
|
||||
If reflect_on_code_block_results=True, we do another inference based on tool results
|
||||
and yield the final text response (or streaming chunks).
|
||||
"""
|
||||
all_messages = system_messages + await model_context.get_messages()
|
||||
llm_messages = cls._get_compatible_context(model_client=model_client, messages=all_messages)
|
||||
|
||||
reflection_result: Optional[CreateResult] = None
|
||||
|
||||
if model_client_stream:
|
||||
async for chunk in model_client.create_stream(llm_messages):
|
||||
if isinstance(chunk, CreateResult):
|
||||
reflection_result = chunk
|
||||
elif isinstance(chunk, str):
|
||||
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
|
||||
else:
|
||||
reflection_result = await model_client.create(llm_messages)
|
||||
|
||||
if not reflection_result or not isinstance(reflection_result.content, str):
|
||||
raise RuntimeError("Reflect on tool use produced no valid text response.")
|
||||
|
||||
# --- NEW: If the reflection produced a thought, yield it ---
|
||||
if reflection_result.thought:
|
||||
thought_event = ThoughtEvent(content=reflection_result.thought, source=agent_name)
|
||||
yield thought_event
|
||||
inner_messages.append(thought_event)
|
||||
|
||||
# Add to context (including thought if present)
|
||||
await model_context.add_message(
|
||||
AssistantMessage(
|
||||
content=reflection_result.content,
|
||||
source=agent_name,
|
||||
thought=getattr(reflection_result, "thought", None),
|
||||
)
|
||||
)
|
||||
|
||||
yield Response(
|
||||
chat_message=TextMessage(
|
||||
content=reflection_result.content,
|
||||
source=agent_name,
|
||||
models_usage=reflection_result.usage,
|
||||
),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
203
agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py
Normal file
203
agent_dhal/agentdhal_agentchat/agents/_message_filter_agent.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
# ------------------------------
|
||||
# Message Filter Config
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class PerSourceFilter(BaseModel):
|
||||
source: str
|
||||
position: Optional[Literal["first", "last"]] = None
|
||||
count: Optional[int] = None
|
||||
|
||||
|
||||
class MessageFilterConfig(BaseModel):
|
||||
per_source: List[PerSourceFilter]
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Component Config
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class MessageFilterAgentConfig(BaseModel):
|
||||
name: str
|
||||
wrapped_agent: ComponentModel
|
||||
filter: MessageFilterConfig
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Message Filter Agent
|
||||
# ------------------------------
|
||||
|
||||
|
||||
class MessageFilterAgent(BaseChatAgent, Component[MessageFilterAgentConfig]):
|
||||
"""
|
||||
A wrapper agent that filters incoming messages before passing them to the inner agent.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
This is useful in scenarios like multi-agent workflows where an agent should only
|
||||
process a subset of the full message history—for example, only the last message
|
||||
from each upstream agent, or only the first message from a specific source.
|
||||
|
||||
Filtering is configured using :class:`MessageFilterConfig`, which supports:
|
||||
- Filtering by message source (e.g., only messages from "user" or another agent)
|
||||
- Selecting the first N or last N messages from each source
|
||||
- If position is `None`, all messages from that source are included
|
||||
|
||||
This agent is compatible with both direct message passing and team-based execution
|
||||
such as :class:`~agentdhal_agentchat.teams.GraphFlow`.
|
||||
|
||||
Example:
|
||||
>>> agent_a = MessageFilterAgent(
|
||||
... name="A",
|
||||
... wrapped_agent=some_other_agent,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=2),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
Example use case with Graph:
|
||||
Suppose you have a looping multi-agent graph: A → B → A → B → C.
|
||||
|
||||
You want:
|
||||
- A to only see the user message and the last message from B
|
||||
- B to see the user message, last message from A, and its own prior responses (for reflection)
|
||||
- C to see the user message and the last message from B
|
||||
|
||||
Wrap the agents like so:
|
||||
|
||||
>>> agent_a = MessageFilterAgent(
|
||||
... name="A",
|
||||
... wrapped_agent=agent_a_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=1),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
>>> agent_b = MessageFilterAgent(
|
||||
... name="B",
|
||||
... wrapped_agent=agent_b_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="A", position="last", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=10),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
>>> agent_c = MessageFilterAgent(
|
||||
... name="C",
|
||||
... wrapped_agent=agent_c_inner,
|
||||
... filter=MessageFilterConfig(
|
||||
... per_source=[
|
||||
... PerSourceFilter(source="user", position="first", count=1),
|
||||
... PerSourceFilter(source="B", position="last", count=1),
|
||||
... ]
|
||||
... ),
|
||||
... )
|
||||
|
||||
Then define the graph:
|
||||
|
||||
>>> graph = DiGraph(
|
||||
... nodes={
|
||||
... "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
... "B": DiGraphNode(
|
||||
... name="B",
|
||||
... edges=[
|
||||
... DiGraphEdge(target="C", condition="exit"),
|
||||
... DiGraphEdge(target="A", condition="loop"),
|
||||
... ],
|
||||
... ),
|
||||
... "C": DiGraphNode(name="C", edges=[]),
|
||||
... },
|
||||
... default_start_node="A",
|
||||
... )
|
||||
|
||||
This will ensure each agent sees only what is needed for its decision or action logic.
|
||||
"""
|
||||
|
||||
component_config_schema = MessageFilterAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.MessageFilterAgent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
wrapped_agent: BaseChatAgent,
|
||||
filter: MessageFilterConfig,
|
||||
):
|
||||
super().__init__(name=name, description=f"{wrapped_agent.description} (with message filtering)")
|
||||
self._wrapped_agent = wrapped_agent
|
||||
self._filter = filter
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return self._wrapped_agent.produced_message_types
|
||||
|
||||
def _apply_filter(self, messages: Sequence[BaseChatMessage]) -> Sequence[BaseChatMessage]:
|
||||
result: List[BaseChatMessage] = []
|
||||
|
||||
for source_filter in self._filter.per_source:
|
||||
msgs = [m for m in messages if m.source == source_filter.source]
|
||||
|
||||
if source_filter.position == "first" and source_filter.count:
|
||||
msgs = msgs[: source_filter.count]
|
||||
elif source_filter.position == "last" and source_filter.count:
|
||||
msgs = msgs[-source_filter.count :]
|
||||
|
||||
result.extend(msgs)
|
||||
|
||||
return result
|
||||
|
||||
async def on_messages(
|
||||
self,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Response:
|
||||
filtered = self._apply_filter(messages)
|
||||
return await self._wrapped_agent.on_messages(filtered, cancellation_token)
|
||||
|
||||
async def on_messages_stream(
|
||||
self,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
cancellation_token: CancellationToken,
|
||||
) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]:
|
||||
filtered = self._apply_filter(messages)
|
||||
async for item in self._wrapped_agent.on_messages_stream(filtered, cancellation_token):
|
||||
yield item
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
await self._wrapped_agent.on_reset(cancellation_token)
|
||||
|
||||
def _to_config(self) -> MessageFilterAgentConfig:
|
||||
return MessageFilterAgentConfig(
|
||||
name=self.name,
|
||||
wrapped_agent=self._wrapped_agent.dump_component(),
|
||||
filter=self._filter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MessageFilterAgentConfig) -> "MessageFilterAgent":
|
||||
wrapped = BaseChatAgent.load_component(config.wrapped_agent)
|
||||
return cls(
|
||||
name=config.name,
|
||||
wrapped_agent=wrapped,
|
||||
filter=config.filter,
|
||||
)
|
||||
302
agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py
Normal file
302
agent_dhal/agentdhal_agentchat/agents/_society_of_mind_agent.py
Normal file
@@ -0,0 +1,302 @@
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.model_context import (
|
||||
ChatCompletionContext,
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
from agentdhal_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.state import SocietyOfMindAgentState
|
||||
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
HandoffMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
||||
class SocietyOfMindAgentConfig(BaseModel):
|
||||
"""The declarative configuration for a SocietyOfMindAgent."""
|
||||
|
||||
name: str
|
||||
team: ComponentModel
|
||||
model_client: ComponentModel
|
||||
description: str | None = None
|
||||
instruction: str | None = None
|
||||
response_prompt: str | None = None
|
||||
model_context: ComponentModel | None = None
|
||||
|
||||
|
||||
class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
|
||||
"""An agent that uses an inner team of agents to generate responses.
|
||||
|
||||
Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream`
|
||||
method is called, it runs the inner team of agents and then uses the
|
||||
model client to generate a response based on the inner team's messages.
|
||||
Once the response is generated, the agent resets the inner team by
|
||||
calling :meth:`Team.reset`.
|
||||
|
||||
Limit context size sent to the model:
|
||||
|
||||
You can limit the number of messages sent to the model by setting
|
||||
the `model_context` parameter to a :class:`~agentdhal_core.model_context.BufferedChatCompletionContext`.
|
||||
This will limit the number of recent messages sent to the model and can be useful
|
||||
when the model has a limit on the number of tokens it can process.
|
||||
You can also create your own model context by subclassing
|
||||
:class:`~agentdhal_core.model_context.ChatCompletionContext`.
|
||||
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
team (Team): The team of agents to use.
|
||||
model_client (ChatCompletionClient): The model client to use for preparing responses.
|
||||
description (str, optional): The description of the agent.
|
||||
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
|
||||
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
|
||||
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~agentdhal_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
|
||||
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
|
||||
agent2 = AssistantAgent(
|
||||
"assistant2",
|
||||
model_client=model_client,
|
||||
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
|
||||
)
|
||||
inner_termination = TextMentionTermination("APPROVE")
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
|
||||
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
|
||||
agent3 = AssistantAgent(
|
||||
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
|
||||
)
|
||||
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
|
||||
|
||||
stream = team.run_stream(task="Write a short story with a surprising ending.")
|
||||
await Console(stream)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_config_schema = SocietyOfMindAgentConfig
|
||||
component_provider_override = "agentdhal_agentchat.agents.SocietyOfMindAgent"
|
||||
|
||||
DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
|
||||
"""str: The default instruction to use when generating a response using the
|
||||
inner team's messages. The instruction will be prepended to the inner team's
|
||||
messages when generating a response using the model. It assumes the role of
|
||||
'system'."""
|
||||
|
||||
DEFAULT_RESPONSE_PROMPT = (
|
||||
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
|
||||
)
|
||||
"""str: The default response prompt to use when generating a response using
|
||||
the inner team's messages. It assumes the role of 'system'."""
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that uses an inner team of agents to generate responses."
|
||||
"""str: The default description for a SocietyOfMindAgent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
team: Team,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
instruction: str = DEFAULT_INSTRUCTION,
|
||||
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
|
||||
model_context: ChatCompletionContext | None = None,
|
||||
) -> None:
|
||||
super().__init__(name=name, description=description)
|
||||
self._team = team
|
||||
self._model_client = model_client
|
||||
self._instruction = instruction
|
||||
self._response_prompt = response_prompt
|
||||
|
||||
if model_context is not None:
|
||||
self._model_context = model_context
|
||||
else:
|
||||
self._model_context = UnboundedChatCompletionContext()
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def model_context(self) -> ChatCompletionContext:
|
||||
"""
|
||||
The model context in use by the agent.
|
||||
"""
|
||||
return self._model_context
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Call the stream method and collect the messages.
|
||||
response: Response | None = None
|
||||
async for msg in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
response = msg
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
# Prepare the task for the team of agents.
|
||||
task_messages = list(messages)
|
||||
|
||||
# Run the team of agents.
|
||||
result: TaskResult | None = None
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
model_context = self._model_context
|
||||
|
||||
prev_content = await model_context.get_messages()
|
||||
if len(prev_content) > 0:
|
||||
prev_message = HandoffMessage(
|
||||
content="relevant previous messages",
|
||||
source=self.name,
|
||||
target="",
|
||||
context=prev_content,
|
||||
)
|
||||
task_messages = [prev_message] + task_messages
|
||||
|
||||
if len(task_messages) == 0:
|
||||
task = None
|
||||
else:
|
||||
task = task_messages
|
||||
|
||||
# Use the new output_task_messages parameter to avoid fragile count-based logic
|
||||
async for inner_msg in self._team.run_stream(
|
||||
task=task, cancellation_token=cancellation_token, output_task_messages=False
|
||||
):
|
||||
if isinstance(inner_msg, TaskResult):
|
||||
result = inner_msg
|
||||
else:
|
||||
yield inner_msg
|
||||
if isinstance(inner_msg, ModelClientStreamingChunkEvent):
|
||||
# Skip the model client streaming chunk events.
|
||||
continue
|
||||
inner_messages.append(inner_msg)
|
||||
assert result is not None
|
||||
|
||||
if len(inner_messages) == 0:
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content="No response."),
|
||||
inner_messages=[],
|
||||
# Response's inner_messages should be empty. Cause that mean is response to outer world.
|
||||
)
|
||||
else:
|
||||
llm_messages: List[LLMMessage] = []
|
||||
|
||||
if self._model_client.model_info.get("multiple_system_messages", False):
|
||||
# The model client supports multiple system messages, so we
|
||||
llm_messages.append(SystemMessage(content=self._instruction))
|
||||
else:
|
||||
# The model client does not support multiple system messages, so we
|
||||
llm_messages.append(UserMessage(content=self._instruction, source="user"))
|
||||
|
||||
# Generate a response using the model client.
|
||||
for message in inner_messages:
|
||||
if isinstance(message, BaseChatMessage):
|
||||
llm_messages.append(message.to_model_message())
|
||||
|
||||
if self._model_client.model_info.get("multiple_system_messages", False):
|
||||
# The model client supports multiple system messages, so we
|
||||
llm_messages.append(SystemMessage(content=self._response_prompt))
|
||||
else:
|
||||
# The model client does not support multiple system messages, so we
|
||||
llm_messages.append(UserMessage(content=self._response_prompt, source="user"))
|
||||
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
|
||||
assert isinstance(completion.content, str)
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
|
||||
inner_messages=[],
|
||||
# Response's inner_messages should be empty. Cause that mean is response to outer world.
|
||||
)
|
||||
|
||||
# Add new user/handoff messages to the model context
|
||||
await self._add_messages_to_context(
|
||||
model_context=model_context,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Reset the team.
|
||||
await self._team.reset()
|
||||
|
||||
@staticmethod
|
||||
async def _add_messages_to_context(
|
||||
model_context: ChatCompletionContext,
|
||||
messages: Sequence[BaseChatMessage],
|
||||
) -> None:
|
||||
"""
|
||||
Add incoming messages to the model context.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HandoffMessage):
|
||||
for llm_msg in msg.context:
|
||||
await model_context.add_message(llm_msg)
|
||||
await model_context.add_message(msg.to_model_message())
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
await self._team.reset()
|
||||
await self._model_context.clear()
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
team_state = await self._team.save_state()
|
||||
state = SocietyOfMindAgentState(inner_team_state=team_state)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
|
||||
await self._team.load_state(society_of_mind_state.inner_team_state)
|
||||
|
||||
def _to_config(self) -> SocietyOfMindAgentConfig:
|
||||
return SocietyOfMindAgentConfig(
|
||||
name=self.name,
|
||||
team=self._team.dump_component(),
|
||||
model_client=self._model_client.dump_component(),
|
||||
description=self.description,
|
||||
instruction=self._instruction,
|
||||
response_prompt=self._response_prompt,
|
||||
model_context=self._model_context.dump_component(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self:
|
||||
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||
team = Team.load_component(config.team)
|
||||
return cls(
|
||||
name=config.name,
|
||||
team=team,
|
||||
model_client=model_client,
|
||||
description=config.description or cls.DEFAULT_DESCRIPTION,
|
||||
instruction=config.instruction or cls.DEFAULT_INSTRUCTION,
|
||||
response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT,
|
||||
model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
|
||||
)
|
||||
249
agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py
Normal file
249
agent_dhal/agentdhal_agentchat/agents/_user_proxy_agent.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||
|
||||
|
||||
# TODO: check if using to_thread fixes this in jupyter
|
||||
async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
task: asyncio.Task[str] = asyncio.create_task(asyncio.to_thread(input, prompt))
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(task)
|
||||
return await task
|
||||
|
||||
|
||||
class UserProxyAgentConfig(BaseModel):
|
||||
"""Declarative configuration for the UserProxyAgent."""
|
||||
|
||||
name: str
|
||||
description: str = "A human user"
|
||||
input_func: str | None = None
|
||||
|
||||
|
||||
class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
|
||||
"""An agent that can represent a human user through an input function.
|
||||
|
||||
This agent can be used to represent a human user in a chat system by providing a custom input function.
|
||||
|
||||
.. note::
|
||||
|
||||
Using :class:`UserProxyAgent` puts a running team in a temporary blocked
|
||||
state until the user responds. So it is important to time out the user input
|
||||
function and cancel using the :class:`~agentdhal_core.CancellationToken` if the user does not respond.
|
||||
The input function should also handle exceptions and return a default response if needed.
|
||||
|
||||
For typical use cases that involve
|
||||
slow human responses, it is recommended to use termination conditions
|
||||
such as :class:`~agentdhal_agentchat.conditions.HandoffTermination` or :class:`~agentdhal_agentchat.conditions.SourceMatchTermination`
|
||||
to stop the running team and return the control to the application.
|
||||
You can run the team again with the user input. This way, the state of the team
|
||||
can be saved and restored when the user responds.
|
||||
|
||||
See `Human-in-the-loop <https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.html>`_ for more information.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str, optional): A description of the agent.
|
||||
input_func (Optional[Callable[[str], str]], Callable[[str, Optional[CancellationToken]], Awaitable[str]]): A function that takes a prompt and returns a user input string.
|
||||
|
||||
For examples of integrating with web and UI frameworks, see the following:
|
||||
|
||||
* `FastAPI <https://github.com/microsoft/autogen/tree/main/python/samples/agentchat_fastapi>`_
|
||||
* `ChainLit <https://github.com/microsoft/autogen/tree/main/python/samples/agentchat_chainlit>`_
|
||||
|
||||
Example:
|
||||
Simple usage case::
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def simple_user_agent():
|
||||
agent = UserProxyAgent("user_proxy")
|
||||
response = await asyncio.create_task(
|
||||
agent.on_messages(
|
||||
[TextMessage(content="What is your name? ", source="user")],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
)
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
print(f"Your name is {response.chat_message.content}")
|
||||
|
||||
Example:
|
||||
Cancellable usage case::
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
token = CancellationToken()
|
||||
agent = UserProxyAgent("user_proxy")
|
||||
|
||||
|
||||
async def timeout(delay: float):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
def cancellation_callback(task: asyncio.Task[Any]):
|
||||
token.cancel()
|
||||
|
||||
|
||||
async def cancellable_user_agent():
|
||||
try:
|
||||
timeout_task = asyncio.create_task(timeout(3))
|
||||
timeout_task.add_done_callback(cancellation_callback)
|
||||
agent_task = asyncio.create_task(
|
||||
agent.on_messages(
|
||||
[TextMessage(content="What is your name? ", source="user")],
|
||||
cancellation_token=token,
|
||||
)
|
||||
)
|
||||
response = await agent_task
|
||||
assert isinstance(response.chat_message, TextMessage)
|
||||
print(f"Your name is {response.chat_message.content}")
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
except BaseException as e:
|
||||
print(f"BaseException: {e}")
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
component_provider_override = "agentdhal_agentchat.agents.UserProxyAgent"
|
||||
component_config_schema = UserProxyAgentConfig
|
||||
|
||||
class InputRequestContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
|
||||
)
|
||||
|
||||
_INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def request_id(cls) -> str:
|
||||
try:
|
||||
return cls._INPUT_REQUEST_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"InputRequestContext.runtime() must be called within the input callback of a UserProxyAgent."
|
||||
) from e
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
description: str = "A human user",
|
||||
input_func: Optional[InputFuncType] = None,
|
||||
) -> None:
|
||||
"""Initialize the UserProxyAgent."""
|
||||
super().__init__(name=name, description=description)
|
||||
self.input_func = input_func or cancellable_input
|
||||
self._is_async = iscoroutinefunction(self.input_func)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""Message types this agent can produce."""
|
||||
return (TextMessage, HandoffMessage)
|
||||
|
||||
def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]:
|
||||
"""Find the HandoffMessage in the message sequence that addresses this agent."""
|
||||
if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
|
||||
if messages[-1].target == self.name:
|
||||
return messages[-1]
|
||||
else:
|
||||
raise RuntimeError(f"Handoff message target does not match agent name: {messages[-1].source}")
|
||||
return None
|
||||
|
||||
async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
|
||||
"""Handle input based on function signature."""
|
||||
try:
|
||||
if self._is_async:
|
||||
# Cast to AsyncInputFunc for proper typing
|
||||
async_func = cast(AsyncInputFunc, self.input_func)
|
||||
return await async_func(prompt, cancellation_token)
|
||||
else:
|
||||
# Cast to SyncInputFunc for proper typing
|
||||
sync_func = cast(SyncInputFunc, self.input_func)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, sync_func, prompt)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handle incoming messages by requesting user input."""
|
||||
try:
|
||||
# Check for handoff first
|
||||
handoff = self._get_latest_handoff(messages)
|
||||
prompt = (
|
||||
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
|
||||
)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
|
||||
yield input_requested_event
|
||||
with UserProxyAgent.InputRequestContext.populate_context(request_id):
|
||||
user_input = await self._get_input(prompt, cancellation_token)
|
||||
|
||||
# Return appropriate message type based on handoff presence
|
||||
if handoff:
|
||||
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
|
||||
else:
|
||||
yield Response(chat_message=TextMessage(content=user_input, source=self.name))
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
|
||||
|
||||
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||
"""Reset agent state."""
|
||||
pass
|
||||
|
||||
def _to_config(self) -> UserProxyAgentConfig:
|
||||
# TODO: Add ability to serialie input_func
|
||||
return UserProxyAgentConfig(name=self.name, description=self.description, input_func=None)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: UserProxyAgentConfig) -> Self:
|
||||
return cls(name=config.name, description=config.description, input_func=None)
|
||||
Reference in New Issue
Block a user