first commit
This commit is contained in:
18
agent_dhal/agentdhal_agentchat/base/__init__.py
Normal file
18
agent_dhal/agentdhal_agentchat/base/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from ._chat_agent import ChatAgent, Response
|
||||
from ._handoff import Handoff
|
||||
from ._task import TaskResult, TaskRunner
|
||||
from ._team import Team
|
||||
from ._termination import AndTerminationCondition, OrTerminationCondition, TerminatedException, TerminationCondition
|
||||
|
||||
__all__ = [
|
||||
"ChatAgent",
|
||||
"Response",
|
||||
"Team",
|
||||
"TerminatedException",
|
||||
"TerminationCondition",
|
||||
"AndTerminationCondition",
|
||||
"OrTerminationCondition",
|
||||
"TaskResult",
|
||||
"TaskRunner",
|
||||
"Handoff",
|
||||
]
|
||||
94
agent_dhal/agentdhal_agentchat/base/_chat_agent.py
Normal file
94
agent_dhal/agentdhal_agentchat/base/_chat_agent.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Mapping, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken, ComponentBase
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Response:
|
||||
"""A response from calling :meth:`ChatAgent.on_messages`."""
|
||||
|
||||
chat_message: SerializeAsAny[BaseChatMessage]
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] | None = None
|
||||
"""Inner messages produced by the agent, they can be :class:`BaseAgentEvent`
|
||||
or :class:`BaseChatMessage`."""
|
||||
|
||||
|
||||
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
|
||||
"""Protocol for a chat agent."""
|
||||
|
||||
component_type = "agent"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the agent. This is used by team to uniquely identify
|
||||
the agent. It should be unique within the team."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""The description of the agent. This is used by team to
|
||||
make decisions about which agents to use. The description should
|
||||
describe the agent's capabilities and how to interact with it."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of inner messages and
|
||||
and the final item is the response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Resets the agent to its initialization state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_pause(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is paused. The agent may be running in :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` when this method is called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_resume(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Called when the agent is resumed. The agent may be running in :meth:`on_messages` or
|
||||
:meth:`on_messages_stream` when this method is called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save agent state for later restoration"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Restore agent from saved state"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Release any resources held by the agent."""
|
||||
...
|
||||
62
agent_dhal/agentdhal_agentchat/base/_handoff.py
Normal file
62
agent_dhal/agentdhal_agentchat/base/_handoff.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from agentdhal_core.tools import BaseTool, FunctionTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class Handoff(BaseModel):
|
||||
"""Handoff configuration."""
|
||||
|
||||
target: str
|
||||
"""The name of the target agent to handoff to."""
|
||||
|
||||
description: str = Field(default="")
|
||||
"""The description of the handoff such as the condition under which it should happen and the target agent's ability.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
name: str = Field(default="")
|
||||
"""The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
|
||||
|
||||
message: str = Field(default="")
|
||||
"""The message to the target agent.
|
||||
By default, it will be the result for the handoff tool.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if not values.get("description"):
|
||||
values["description"] = f"Handoff to {values['target']}."
|
||||
if not values.get("name"):
|
||||
values["name"] = f"transfer_to_{values['target']}".lower()
|
||||
else:
|
||||
name = values["name"]
|
||||
if not isinstance(name, str):
|
||||
raise ValueError(f"Handoff name must be a string: {values['name']}")
|
||||
# Check if name is a valid identifier.
|
||||
if not name.isidentifier():
|
||||
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
|
||||
if not values.get("message"):
|
||||
values["message"] = (
|
||||
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
|
||||
"""Create a handoff tool from this handoff configuration."""
|
||||
|
||||
def _handoff_tool() -> str:
|
||||
return self.message
|
||||
|
||||
return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True)
|
||||
|
||||
"""
|
||||
The tool that can be used to handoff to the target agent.
|
||||
Typically, the results of the tool's execution are provided to the target agent.
|
||||
"""
|
||||
65
agent_dhal/agentdhal_agentchat/base/_task.py
Normal file
65
agent_dhal/agentdhal_agentchat/base/_task.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import AsyncGenerator, Protocol, Sequence
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
"""Result of running a task."""
|
||||
|
||||
messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]]
|
||||
"""Messages produced by the task."""
|
||||
|
||||
stop_reason: str | None = None
|
||||
"""The reason the task stopped."""
|
||||
|
||||
|
||||
class TaskRunner(Protocol):
|
||||
"""A task runner."""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in :attr:`TaskResult.messages`. Defaults to True for backward compatibility.
|
||||
"""
|
||||
...
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
output_task_messages: bool = True,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
:class:`TaskResult` as the last item in the stream.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task.
|
||||
|
||||
Args:
|
||||
task: The task to run. Can be a string, a single message, or a sequence of messages.
|
||||
cancellation_token: The cancellation token to kill the task immediately.
|
||||
output_task_messages: Whether to include task messages in the output stream. Defaults to True for backward compatibility.
|
||||
"""
|
||||
...
|
||||
54
agent_dhal/agentdhal_agentchat/base/_team.py
Normal file
54
agent_dhal/agentdhal_agentchat/base/_team.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping
|
||||
|
||||
from agentdhal_core import ComponentBase
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
class Team(ABC, TaskRunner, ComponentBase[BaseModel]):
|
||||
component_type = "team"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the team. This is used by team to uniquely identify itself
|
||||
in a larger team of teams."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""A description of the team. This is used to provide context about the
|
||||
team and its purpose to its parent orchestrator."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and all its participants to its initial state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def pause(self) -> None:
|
||||
"""Pause the team and all its participants. This is useful for
|
||||
pausing the :meth:`agentdhal_agentchat.base.TaskRunner.run` or
|
||||
:meth:`agentdhal_agentchat.base.TaskRunner.run_stream` methods from
|
||||
concurrently, while keeping them alive."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def resume(self) -> None:
|
||||
"""Resume the team and all its participants from a pause after
|
||||
:meth:`pause` was called."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the current state of the team."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the team."""
|
||||
...
|
||||
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence
|
||||
|
||||
from agentdhal_core import Component, ComponentBase, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class TerminatedException(BaseException): ...
|
||||
|
||||
|
||||
class TerminationCondition(ABC, ComponentBase[BaseModel]):
|
||||
"""A stateful condition that determines when a conversation should be terminated.
|
||||
|
||||
A termination condition is a callable that takes a sequence of BaseChatMessage objects
|
||||
since the last time the condition was called, and returns a StopMessage if the
|
||||
conversation should be terminated, or None otherwise.
|
||||
Once a termination condition has been reached, it must be reset before it can be used again.
|
||||
|
||||
Termination conditions can be combined using the AND and OR operators.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination, TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
|
||||
cond1 = MaxMessageTermination(10) | TextMentionTermination("TERMINATE")
|
||||
|
||||
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
|
||||
cond2 = MaxMessageTermination(10) & TextMentionTermination("TERMINATE")
|
||||
|
||||
# ...
|
||||
|
||||
# Reset the termination condition.
|
||||
await cond1.reset()
|
||||
await cond2.reset()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "termination"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def terminated(self) -> bool:
|
||||
"""Check if the termination condition has been reached"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
"""Check if the conversation should be terminated based on the messages received
|
||||
since the last time the condition was called.
|
||||
Return a StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Args:
|
||||
messages: The messages received since the last time the condition was called.
|
||||
|
||||
Returns:
|
||||
StopMessage | None: A StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the termination condition."""
|
||||
...
|
||||
|
||||
def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an AND operation."""
|
||||
return AndTerminationCondition(self, other)
|
||||
|
||||
def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an OR operation."""
|
||||
return OrTerminationCondition(self, other)
|
||||
|
||||
|
||||
class AndTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
|
||||
|
||||
class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
|
||||
component_config_schema = AndTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.AndTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
self._stop_messages: List[StopMessage] = []
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return all(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached.")
|
||||
# Check all remaining conditions.
|
||||
stop_messages = await asyncio.gather(
|
||||
*[condition(messages) for condition in self._conditions if not condition.terminated]
|
||||
)
|
||||
# Collect stop messages.
|
||||
for stop_message in stop_messages:
|
||||
if stop_message is not None:
|
||||
self._stop_messages.append(stop_message)
|
||||
if any(stop_message is None for stop_message in stop_messages):
|
||||
# If any remaining condition has not reached termination, it is not terminated.
|
||||
return None
|
||||
content = ", ".join(stop_message.content for stop_message in self._stop_messages)
|
||||
source = ", ".join(stop_message.source for stop_message in self._stop_messages)
|
||||
return StopMessage(content=content, source=source)
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
self._stop_messages.clear()
|
||||
|
||||
def _to_config(self) -> AndTerminationConditionConfig:
|
||||
"""Convert the AND termination condition to a config."""
|
||||
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
|
||||
"""Create an AND termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
|
||||
|
||||
class OrTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
"""List of termination conditions where any one being satisfied is sufficient."""
|
||||
|
||||
|
||||
class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
|
||||
component_config_schema = OrTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.OrTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return any(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise RuntimeError("Termination condition has already been reached")
|
||||
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
|
||||
stop_messages_filter = [stop_message for stop_message in stop_messages if stop_message is not None]
|
||||
if len(stop_messages_filter) > 0:
|
||||
content = ", ".join(stop_message.content for stop_message in stop_messages_filter)
|
||||
source = ", ".join(stop_message.source for stop_message in stop_messages_filter)
|
||||
return StopMessage(content=content, source=source)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
|
||||
def _to_config(self) -> OrTerminationConditionConfig:
|
||||
"""Convert the OR termination condition to a config."""
|
||||
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
|
||||
"""Create an OR termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
Reference in New Issue
Block a user