first commit
This commit is contained in:
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
179
agent_dhal/agentdhal_agentchat/base/_termination.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence
|
||||
|
||||
from agentdhal_core import Component, ComponentBase, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..messages import BaseAgentEvent, BaseChatMessage, StopMessage
|
||||
|
||||
|
||||
class TerminatedException(BaseException): ...
|
||||
|
||||
|
||||
class TerminationCondition(ABC, ComponentBase[BaseModel]):
|
||||
"""A stateful condition that determines when a conversation should be terminated.
|
||||
|
||||
A termination condition is a callable that takes a sequence of BaseChatMessage objects
|
||||
since the last time the condition was called, and returns a StopMessage if the
|
||||
conversation should be terminated, or None otherwise.
|
||||
Once a termination condition has been reached, it must be reset before it can be used again.
|
||||
|
||||
Termination conditions can be combined using the AND and OR operators.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.conditions import MaxMessageTermination, TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
|
||||
cond1 = MaxMessageTermination(10) | TextMentionTermination("TERMINATE")
|
||||
|
||||
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
|
||||
cond2 = MaxMessageTermination(10) & TextMentionTermination("TERMINATE")
|
||||
|
||||
# ...
|
||||
|
||||
# Reset the termination condition.
|
||||
await cond1.reset()
|
||||
await cond2.reset()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "termination"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def terminated(self) -> bool:
|
||||
"""Check if the termination condition has been reached"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
"""Check if the conversation should be terminated based on the messages received
|
||||
since the last time the condition was called.
|
||||
Return a StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Args:
|
||||
messages: The messages received since the last time the condition was called.
|
||||
|
||||
Returns:
|
||||
StopMessage | None: A StopMessage if the conversation should be terminated, or None otherwise.
|
||||
|
||||
Raises:
|
||||
TerminatedException: If the termination condition has already been reached."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""Reset the termination condition."""
|
||||
...
|
||||
|
||||
def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an AND operation."""
|
||||
return AndTerminationCondition(self, other)
|
||||
|
||||
def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
|
||||
"""Combine two termination conditions with an OR operation."""
|
||||
return OrTerminationCondition(self, other)
|
||||
|
||||
|
||||
class AndTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
|
||||
|
||||
class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
|
||||
component_config_schema = AndTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.AndTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
self._stop_messages: List[StopMessage] = []
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return all(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise TerminatedException("Termination condition has already been reached.")
|
||||
# Check all remaining conditions.
|
||||
stop_messages = await asyncio.gather(
|
||||
*[condition(messages) for condition in self._conditions if not condition.terminated]
|
||||
)
|
||||
# Collect stop messages.
|
||||
for stop_message in stop_messages:
|
||||
if stop_message is not None:
|
||||
self._stop_messages.append(stop_message)
|
||||
if any(stop_message is None for stop_message in stop_messages):
|
||||
# If any remaining condition has not reached termination, it is not terminated.
|
||||
return None
|
||||
content = ", ".join(stop_message.content for stop_message in self._stop_messages)
|
||||
source = ", ".join(stop_message.source for stop_message in self._stop_messages)
|
||||
return StopMessage(content=content, source=source)
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
self._stop_messages.clear()
|
||||
|
||||
def _to_config(self) -> AndTerminationConditionConfig:
|
||||
"""Convert the AND termination condition to a config."""
|
||||
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
|
||||
"""Create an AND termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
|
||||
|
||||
class OrTerminationConditionConfig(BaseModel):
|
||||
conditions: List[ComponentModel]
|
||||
"""List of termination conditions where any one being satisfied is sufficient."""
|
||||
|
||||
|
||||
class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
|
||||
component_config_schema = OrTerminationConditionConfig
|
||||
component_type = "termination"
|
||||
component_provider_override = "agentdhal_agentchat.base.OrTerminationCondition"
|
||||
|
||||
def __init__(self, *conditions: TerminationCondition) -> None:
|
||||
self._conditions = conditions
|
||||
|
||||
@property
|
||||
def terminated(self) -> bool:
|
||||
return any(condition.terminated for condition in self._conditions)
|
||||
|
||||
async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
|
||||
if self.terminated:
|
||||
raise RuntimeError("Termination condition has already been reached")
|
||||
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
|
||||
stop_messages_filter = [stop_message for stop_message in stop_messages if stop_message is not None]
|
||||
if len(stop_messages_filter) > 0:
|
||||
content = ", ".join(stop_message.content for stop_message in stop_messages_filter)
|
||||
source = ", ".join(stop_message.source for stop_message in stop_messages_filter)
|
||||
return StopMessage(content=content, source=source)
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
for condition in self._conditions:
|
||||
await condition.reset()
|
||||
|
||||
def _to_config(self) -> OrTerminationConditionConfig:
|
||||
"""Convert the OR termination condition to a config."""
|
||||
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
|
||||
"""Create an OR termination condition from a config."""
|
||||
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
|
||||
return cls(*conditions)
|
||||
Reference in New Issue
Block a user