first commit
This commit is contained in:
142
agent_dhal/agentdhal_core/__init__.py
Normal file
142
agent_dhal/agentdhal_core/__init__.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# AgentDhal Core Module - Self-contained version
|
||||
__version__ = "1.0.0"
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cache_store import CacheStore, InMemoryStore
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._closure_agent import ClosureAgent, ClosureContext
|
||||
from ._component_config import (
|
||||
Component,
|
||||
ComponentBase,
|
||||
ComponentFromConfig,
|
||||
ComponentLoader,
|
||||
ComponentModel,
|
||||
ComponentSchemaType,
|
||||
ComponentToConfig,
|
||||
ComponentType,
|
||||
is_component_class,
|
||||
is_component_instance,
|
||||
)
|
||||
from ._constants import (
|
||||
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._constants import (
|
||||
ROOT_LOGGER_NAME as ROOT_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._constants import (
|
||||
TRACE_LOGGER_NAME as TRACE_LOGGER_NAME_ALIAS,
|
||||
)
|
||||
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
from ._intervention import (
|
||||
DefaultInterventionHandler,
|
||||
DropMessage,
|
||||
InterventionHandler,
|
||||
)
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._routed_agent import RoutedAgent, event, message_handler, rpc
|
||||
from ._serialization import (
|
||||
JSON_DATA_CONTENT_TYPE as JSON_DATA_CONTENT_TYPE_ALIAS,
|
||||
)
|
||||
from ._serialization import (
|
||||
PROTOBUF_DATA_CONTENT_TYPE as PROTOBUF_DATA_CONTENT_TYPE_ALIAS,
|
||||
)
|
||||
from ._serialization import (
|
||||
MessageSerializer,
|
||||
UnknownPayload,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._telemetry import (
|
||||
trace_create_agent_span,
|
||||
trace_invoke_agent_span,
|
||||
trace_tool_span,
|
||||
)
|
||||
from ._topic import TopicId
|
||||
from ._type_prefix_subscription import TypePrefixSubscription
|
||||
from ._type_subscription import TypeSubscription
|
||||
from ._types import FunctionCall
|
||||
|
||||
EVENT_LOGGER_NAME = EVENT_LOGGER_NAME_ALIAS
|
||||
"""The name of the logger used for structured events."""
|
||||
|
||||
ROOT_LOGGER_NAME = ROOT_LOGGER_NAME_ALIAS
|
||||
"""The name of the root logger."""
|
||||
|
||||
TRACE_LOGGER_NAME = TRACE_LOGGER_NAME_ALIAS
|
||||
"""Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
|
||||
JSON_DATA_CONTENT_TYPE = JSON_DATA_CONTENT_TYPE_ALIAS
|
||||
"""The content type for JSON data."""
|
||||
|
||||
PROTOBUF_DATA_CONTENT_TYPE = PROTOBUF_DATA_CONTENT_TYPE_ALIAS
|
||||
"""The content type for Protobuf data."""
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentId",
|
||||
"AgentProxy",
|
||||
"AgentMetadata",
|
||||
"AgentRuntime",
|
||||
"BaseAgent",
|
||||
"CacheStore",
|
||||
"InMemoryStore",
|
||||
"CancellationToken",
|
||||
"AgentInstantiationContext",
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
"MessageContext",
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
"MessageHandlerContext",
|
||||
"MessageSerializer",
|
||||
"try_get_known_serializers_for_type",
|
||||
"UnknownPayload",
|
||||
"Image",
|
||||
"RoutedAgent",
|
||||
"ClosureAgent",
|
||||
"ClosureContext",
|
||||
"message_handler",
|
||||
"event",
|
||||
"rpc",
|
||||
"FunctionCall",
|
||||
"TypeSubscription",
|
||||
"DefaultSubscription",
|
||||
"DefaultTopicId",
|
||||
"default_subscription",
|
||||
"type_subscription",
|
||||
"TypePrefixSubscription",
|
||||
"JSON_DATA_CONTENT_TYPE",
|
||||
"PROTOBUF_DATA_CONTENT_TYPE",
|
||||
"SingleThreadedAgentRuntime",
|
||||
"ROOT_LOGGER_NAME",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"TRACE_LOGGER_NAME",
|
||||
"Component",
|
||||
"ComponentBase",
|
||||
"ComponentFromConfig",
|
||||
"ComponentLoader",
|
||||
"ComponentModel",
|
||||
"ComponentSchemaType",
|
||||
"ComponentToConfig",
|
||||
"ComponentType",
|
||||
"is_component_class",
|
||||
"is_component_instance",
|
||||
"DropMessage",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
"trace_create_agent_span",
|
||||
"trace_invoke_agent_span",
|
||||
"trace_tool_span",
|
||||
]
|
||||
64
agent_dhal/agentdhal_core/_agent.py
Normal file
64
agent_dhal/agentdhal_core/_agent.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._message_context import MessageContext
|
||||
|
||||
# Forward declaration for type checking only
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
"""Metadata of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""ID of the agent."""
|
||||
...
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
|
||||
"""Function used to bind an Agent instance to an `AgentRuntime`.
|
||||
|
||||
Args:
|
||||
agent_id (AgentId): ID of the agent.
|
||||
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
|
||||
"""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
|
||||
|
||||
Args:
|
||||
message (Any): Received message. Type is one of the types in `subscriptions`.
|
||||
ctx (MessageContext): Context of the message.
|
||||
|
||||
Returns:
|
||||
Any: Response to the message. Can be None.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the message was cancelled.
|
||||
CantHandleException: If the agent cannot handle the message.
|
||||
"""
|
||||
...
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
...
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Called when the runtime is closed"""
|
||||
...
|
||||
68
agent_dhal/agentdhal_core/_agent_id.py
Normal file
68
agent_dhal/agentdhal_core/_agent_id.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._agent_type import AgentType
|
||||
|
||||
|
||||
def is_valid_agent_type(value: str) -> bool:
|
||||
return bool(re.match(r"^[\w\-\.]+\Z", value))
|
||||
|
||||
|
||||
class AgentId:
|
||||
"""
|
||||
Agent ID uniquely identifies an agent instance within an agent runtime - including distributed runtime. It is the 'address' of the agent instance for receiving messages.
|
||||
|
||||
See here for more information: :ref:`agentid_and_lifecycle`
|
||||
"""
|
||||
|
||||
def __init__(self, type: str | AgentType, key: str) -> None:
|
||||
if isinstance(type, AgentType):
|
||||
type = type.type
|
||||
|
||||
if not is_valid_agent_type(type):
|
||||
raise ValueError(rf"Invalid agent type: {type}. Allowed values MUST match the regex: `^[\w\-\.]+\Z`")
|
||||
|
||||
self._type = type
|
||||
self._key = key
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._type, self._key))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self._type}/{self._key}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'AgentId(type="{self._type}", key="{self._key}")'
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if not isinstance(value, AgentId):
|
||||
return False
|
||||
return self._type == value.type and self._key == value.key
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, agent_id: str) -> Self:
|
||||
"""Convert a string of the format ``type/key`` into an AgentId"""
|
||||
items = agent_id.split("/", maxsplit=1)
|
||||
if len(items) != 2:
|
||||
raise ValueError(f"Invalid agent id: {agent_id}")
|
||||
type, key = items[0], items[1]
|
||||
return cls(type, key)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""
|
||||
An identifier that associates an agent with a specific factory function.
|
||||
|
||||
Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
|
||||
"""
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""
|
||||
Agent instance identifier.
|
||||
|
||||
Strings may only be composed of alphanumeric letters (a-z) and (0-9), or underscores (_).
|
||||
"""
|
||||
return self._key
|
||||
126
agent_dhal/agentdhal_core/_agent_instantiation.py
Normal file
126
agent_dhal/agentdhal_core/_agent_instantiation.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentInstantiationContext:
|
||||
"""A static class that provides context for agent instantiation.
|
||||
|
||||
This static class can be used to access the current runtime and agent ID
|
||||
during agent instantiation -- inside the factory function or the agent's
|
||||
class constructor.
|
||||
|
||||
Example:
|
||||
|
||||
Get the current runtime and agent ID inside the factory function and
|
||||
the agent's constructor:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
message_handler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class TestAgent(RoutedAgent):
|
||||
def __init__(self, description: str):
|
||||
super().__init__(description)
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from constructor: {agent_id}")
|
||||
|
||||
@message_handler
|
||||
async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
print(f"Received message: {message.content}")
|
||||
|
||||
|
||||
def test_agent_factory() -> TestAgent:
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from factory: {agent_id}")
|
||||
return TestAgent(description="Test agent")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a SingleThreadedAgentRuntime instance.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Start the runtime.
|
||||
runtime.start()
|
||||
|
||||
# Register the agent type with a factory function.
|
||||
await runtime.register_factory("test_agent", test_agent_factory)
|
||||
|
||||
# Send a message to the agent. The runtime will instantiate the agent and call the message handler.
|
||||
await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default"))
|
||||
|
||||
# Stop the runtime.
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
)
|
||||
|
||||
_AGENT_INSTANTIATION_CONTEXT_VAR: ClassVar[ContextVar[tuple[AgentRuntime, AgentId]]] = ContextVar(
|
||||
"_AGENT_INSTANTIATION_CONTEXT_VAR"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AgentInstantiationContext._AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def current_runtime(cls) -> AgentRuntime:
|
||||
try:
|
||||
return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[0]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def current_agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls._AGENT_INSTANTIATION_CONTEXT_VAR.get()[1]
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def is_in_factory_call(cls) -> bool:
|
||||
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
|
||||
return False
|
||||
return True
|
||||
7
agent_dhal/agentdhal_core/_agent_metadata.py
Normal file
7
agent_dhal/agentdhal_core/_agent_metadata.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class AgentMetadata(TypedDict):
|
||||
type: str
|
||||
key: str
|
||||
description: str
|
||||
56
agent_dhal/agentdhal_core/_agent_proxy.py
Normal file
56
agent_dhal/agentdhal_core/_agent_proxy.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentProxy:
|
||||
"""A helper class that allows you to use an :class:`~agentdhal_core.AgentId` in place of its associated :class:`~agentdhal_core.Agent`"""
|
||||
|
||||
def __init__(self, agent: AgentId, runtime: AgentRuntime):
|
||||
self._agent = agent
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""Target agent for this proxy"""
|
||||
return self._agent
|
||||
|
||||
@property
|
||||
def metadata(self) -> Awaitable[AgentMetadata]:
|
||||
"""Metadata of the agent."""
|
||||
return self._runtime.agent_metadata(self._agent)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
sender=sender,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
return await self._runtime.agent_save_state(self._agent)
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
await self._runtime.agent_load_state(self._agent, state)
|
||||
295
agent_dhal/agentdhal_core/_agent_runtime.py
Normal file
295
agent_dhal/agentdhal_core/_agent_runtime.py
Normal file
@@ -0,0 +1,295 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._serialization import MessageSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""Send a message to an agent and get a response.
|
||||
|
||||
Args:
|
||||
message (Any): The message to send.
|
||||
recipient (AgentId): The agent to send the message to.
|
||||
sender (AgentId | None, optional): Agent which sent the message. Should **only** be None if this was sent from no agent, such as directly to the runtime externally. Defaults to None.
|
||||
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None.
|
||||
|
||||
Raises:
|
||||
CantHandleException: If the recipient cannot handle the message.
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
Other: Any other exception raised by the recipient.
|
||||
|
||||
Returns:
|
||||
Any: The response from the agent.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
"""Publish a message to all agents in the given namespace, or if no namespace is provided, the namespace of the sender.
|
||||
|
||||
No responses are expected from publishing.
|
||||
|
||||
Args:
|
||||
message (Any): The message to publish.
|
||||
topic_id (TopicId): The topic to publish the message to.
|
||||
sender (AgentId | None, optional): The agent which sent the message. Defaults to None.
|
||||
cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress. Defaults to None.
|
||||
message_id (str | None, optional): The message id. If None, a new message id will be generated. Defaults to None. This message id must be unique. and is recommended to be a UUID.
|
||||
|
||||
Raises:
|
||||
UndeliverableException: If the message cannot be delivered.
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
type: str | AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
expected_class: type[T] | None = None,
|
||||
) -> AgentType:
|
||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique. This API does not add any subscriptions.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a low level API and usually the agent class's `register` method should be used instead, as this also handles subscriptions automatically.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import AgentRuntime, MessageContext, RoutedAgent, event
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("My core agent")
|
||||
|
||||
@event
|
||||
async def handler(self, message: UserMessage, context: MessageContext) -> None:
|
||||
print("Event received: ", message.content)
|
||||
|
||||
|
||||
async def my_agent_factory():
|
||||
return MyAgent()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime: AgentRuntime = ... # type: ignore
|
||||
await runtime.register_factory("my_agent", lambda: MyAgent())
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agentdhal_core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed.
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_agent_instance(
|
||||
self,
|
||||
agent_instance: Agent,
|
||||
agent_id: AgentId,
|
||||
) -> AgentId:
|
||||
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("My core agent")
|
||||
|
||||
@event
|
||||
async def handler(self, message: UserMessage, context: MessageContext) -> None:
|
||||
print("Event received: ", message.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime: AgentRuntime = ... # type: ignore
|
||||
agent = MyAgent()
|
||||
await runtime.register_agent_instance(
|
||||
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
|
||||
)
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
agent_instance (Agent): A concrete instance of the agent.
|
||||
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
|
||||
"""
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
|
||||
|
||||
If the underlying agent is not accessible, this will raise an exception.
|
||||
|
||||
Args:
|
||||
id (AgentId): The agent id.
|
||||
type (Type[T], optional): The expected type of the agent. Defaults to Agent.
|
||||
|
||||
Returns:
|
||||
T: The concrete agent instance.
|
||||
|
||||
Raises:
|
||||
LookupError: If the agent is not found.
|
||||
NotAccessibleError: If the agent is not accessible, for example if it is located remotely.
|
||||
TypeError: If the agent is not of the expected type.
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get(self, id: AgentId, /, *, lazy: bool = ...) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
async def get(self, type: AgentType | str, /, key: str = ..., *, lazy: bool = ...) -> AgentId: ...
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
) -> AgentId: ...
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
"""Get the metadata for an agent.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
|
||||
Returns:
|
||||
AgentMetadata: The agent metadata.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
"""Save the state of a single agent.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of a single agent.
|
||||
|
||||
Args:
|
||||
agent (AgentId): The agent id.
|
||||
state (Mapping[str, Any]): The saved state.
|
||||
"""
|
||||
...
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
"""Add a new subscription that the runtime should fulfill when processing published messages
|
||||
|
||||
Args:
|
||||
subscription (Subscription): The subscription to add
|
||||
"""
|
||||
...
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
"""Remove a subscription from the runtime
|
||||
|
||||
Args:
|
||||
id (str): id of the subscription to remove
|
||||
|
||||
Raises:
|
||||
LookupError: If the subscription does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
"""Add a new message serialization serializer to the runtime
|
||||
|
||||
Note: This will deduplicate serializers based on the type_name and data_content_type properties
|
||||
|
||||
Args:
|
||||
serializer (MessageSerializer[Any] | Sequence[MessageSerializer[Any]]): The serializer/s to add
|
||||
"""
|
||||
...
|
||||
7
agent_dhal/agentdhal_core/_agent_type.py
Normal file
7
agent_dhal/agentdhal_core/_agent_type.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class AgentType:
|
||||
type: str
|
||||
"""String representation of this agent type."""
|
||||
254
agent_dhal/agentdhal_core/_base_agent.py
Normal file
254
agent_dhal/agentdhal_core/_base_agent.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._subscription import Subscription, UnboundSubscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
from ._type_prefix_subscription import TypePrefixSubscription
|
||||
from ._type_subscription import TypeSubscription
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
# Decorator for adding an unbound subscription to an agent
|
||||
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
""":meta private:"""
|
||||
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
cls.internal_unbound_subscriptions_list.append(subscription)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handles(
|
||||
type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
if serializer is None:
|
||||
serializer_list = try_get_known_serializers_for_type(type)
|
||||
else:
|
||||
serializer_list = [serializer] if not isinstance(serializer, Sequence) else serializer
|
||||
|
||||
if len(serializer_list) == 0:
|
||||
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
|
||||
|
||||
cls.internal_extra_handles_types.append((type, serializer_list))
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
|
||||
""":meta private:"""
|
||||
internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
|
||||
""":meta private:"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Automatically set class_variable in each subclass so that they are not shared between subclasses
|
||||
cls.internal_extra_handles_types = []
|
||||
cls.internal_unbound_subscriptions_list = []
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
return cls.internal_extra_handles_types
|
||||
|
||||
@classmethod
|
||||
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
|
||||
return cls.internal_unbound_subscriptions_list
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
if AgentInstantiationContext.is_in_factory_call():
|
||||
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
|
||||
self._id = AgentInstantiationContext.current_agent_id()
|
||||
if not isinstance(description, str):
|
||||
raise ValueError("Agent description must be a string")
|
||||
self._description = description
|
||||
|
||||
async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
|
||||
if hasattr(self, "_id"):
|
||||
if self._id != id:
|
||||
raise RuntimeError("Agent is already bound to a different ID")
|
||||
|
||||
if hasattr(self, "_runtime"):
|
||||
if self._runtime != runtime:
|
||||
raise RuntimeError("Agent is already bound to a different runtime")
|
||||
|
||||
self._id = id
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.id.type
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def runtime(self) -> AgentRuntime:
|
||||
return self._runtime
|
||||
|
||||
@final
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
return await self.on_message_impl(message, ctx)
|
||||
|
||||
@abstractmethod
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: ...
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""See :py:meth:`agentdhal_core.AgentRuntime.send_message` for more information."""
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
return {}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_instance(
|
||||
self,
|
||||
runtime: AgentRuntime,
|
||||
agent_id: AgentId,
|
||||
*,
|
||||
skip_class_subscriptions: bool = True,
|
||||
skip_direct_message_subscription: bool = False,
|
||||
) -> AgentId:
|
||||
"""
|
||||
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
|
||||
"""
|
||||
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)
|
||||
|
||||
id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
|
||||
await runtime.add_subscription(id_subscription)
|
||||
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in self._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
if not skip_direct_message_subscription:
|
||||
# Additionally adds a special prefix subscription for this agent to receive direct messages
|
||||
try:
|
||||
await runtime.add_subscription(
|
||||
TypePrefixSubscription(
|
||||
# The prefix MUST include ":" to avoid collisions with other agents
|
||||
topic_type_prefix=agent_id.type + ":",
|
||||
agent_type=agent_id.type,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
# We don't care if the subscription already exists
|
||||
pass
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in self._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_id
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls,
|
||||
runtime: AgentRuntime,
|
||||
type: str,
|
||||
factory: Callable[[], Self | Awaitable[Self]],
|
||||
*,
|
||||
skip_class_subscriptions: bool = False,
|
||||
skip_direct_message_subscription: bool = False,
|
||||
) -> AgentType:
|
||||
agent_type = AgentType(type)
|
||||
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
||||
if not skip_class_subscriptions:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions: List[Subscription] = []
|
||||
for unbound_subscription in cls._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
if not skip_direct_message_subscription:
|
||||
# Additionally adds a special prefix subscription for this agent to receive direct messages
|
||||
await runtime.add_subscription(
|
||||
TypePrefixSubscription(
|
||||
# The prefix MUST include ":" to avoid collisions with other agents
|
||||
topic_type_prefix=agent_type.type + ":",
|
||||
agent_type=agent_type.type,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in cls._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
70
agent_dhal/agentdhal_core/_cache_store.py
Normal file
70
agent_dhal/agentdhal_core/_cache_store.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._component_config import Component, ComponentBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]):
|
||||
"""
|
||||
This protocol defines the basic interface for store/cache operations.
|
||||
|
||||
Sub-classes should handle the lifecycle of underlying storage.
|
||||
"""
|
||||
|
||||
component_type = "cache_store"
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
|
||||
"""
|
||||
Retrieve an item from the store.
|
||||
|
||||
Args:
|
||||
key: The key identifying the item in the store.
|
||||
default (optional): The default value to return if the key is not found.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The value associated with the key if found, else the default value.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: T) -> None:
|
||||
"""
|
||||
Set an item in the store.
|
||||
|
||||
Args:
|
||||
key: The key under which the item is to be stored.
|
||||
value: The value to be stored in the store.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryStoreConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]):
|
||||
component_provider_override = "agentdhal_core.InMemoryStore"
|
||||
component_config_schema = InMemoryStoreConfig
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: Dict[str, T] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: T) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def _to_config(self) -> InMemoryStoreConfig:
|
||||
return InMemoryStoreConfig()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: InMemoryStoreConfig) -> Self:
|
||||
return cls()
|
||||
46
agent_dhal/agentdhal_core/_cancellation_token.py
Normal file
46
agent_dhal/agentdhal_core/_cancellation_token.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import threading
|
||||
from asyncio import Future
|
||||
from typing import Any, Callable, List
|
||||
|
||||
|
||||
class CancellationToken:
|
||||
"""A token used to cancel pending async calls"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._cancelled: bool = False
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
self._callbacks: List[Callable[[], None]] = []
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel pending async calls linked to this cancellation token."""
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._cancelled = True
|
||||
for callback in self._callbacks:
|
||||
callback()
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Check if the CancellationToken has been used"""
|
||||
with self._lock:
|
||||
return self._cancelled
|
||||
|
||||
def add_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Attach a callback that will be called when cancel is invoked"""
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
callback()
|
||||
else:
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def link_future(self, future: Future[Any]) -> Future[Any]:
|
||||
"""Link a pending async call to a token to allow its cancellation"""
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
future.cancel()
|
||||
else:
|
||||
|
||||
def _cancel() -> None:
|
||||
future.cancel()
|
||||
|
||||
self._callbacks.append(_cancel)
|
||||
return future
|
||||
241
agent_dhal/agentdhal_core/_closure_agent.py
Normal file
241
agent_dhal/agentdhal_core/_closure_agent.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import try_get_known_serializers_for_type
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
from ._type_helpers import get_types
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
T = TypeVar("T")
|
||||
ClosureAgentType = TypeVar("ClosureAgentType", bound="ClosureAgent")
|
||||
|
||||
|
||||
def get_handled_types_from_closure(
|
||||
closure: Callable[[ClosureAgent, T, MessageContext], Awaitable[Any]],
|
||||
) -> Sequence[type]:
|
||||
args = inspect.getfullargspec(closure)[0]
|
||||
if len(args) != 3:
|
||||
raise AssertionError("Closure must have 4 arguments")
|
||||
|
||||
message_arg_name = args[1]
|
||||
|
||||
type_hints = get_type_hints(closure)
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints[message_arg_name])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
return target_types
|
||||
|
||||
|
||||
class ClosureContext(Protocol):
|
||||
@property
|
||||
def id(self) -> AgentId: ...
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ClosureAgent(BaseAgent, ClosureContext):
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||
*,
|
||||
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||
) -> None:
|
||||
try:
|
||||
runtime = AgentInstantiationContext.current_runtime()
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
) from e
|
||||
|
||||
self._runtime: AgentRuntime = runtime
|
||||
self._id: AgentId = id
|
||||
self._description = description
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
self._expected_types = handled_types
|
||||
self._closure = closure
|
||||
self._unknown_type_policy = unknown_type_policy
|
||||
super().__init__(description)
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(
|
||||
key=self._id.key,
|
||||
type=self._id.type,
|
||||
description=self._description,
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def runtime(self) -> AgentRuntime:
|
||||
return self._runtime
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
|
||||
if type(message) not in self._expected_types:
|
||||
if self._unknown_type_policy == "warn":
|
||||
warnings.warn(
|
||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
|
||||
stacklevel=1,
|
||||
)
|
||||
return None
|
||||
elif self._unknown_type_policy == "error":
|
||||
raise CantHandleException(
|
||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
|
||||
)
|
||||
|
||||
return await self._closure(self, message, ctx)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Closure agents do not have state. So this method always returns an empty dictionary."""
|
||||
return {}
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Closure agents do not have state. So this method does nothing."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def register_closure(
|
||||
cls,
|
||||
runtime: AgentRuntime,
|
||||
type: str,
|
||||
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||
*,
|
||||
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||
skip_direct_message_subscription: bool = False,
|
||||
description: str = "",
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
|
||||
) -> AgentType:
|
||||
"""The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime.
|
||||
|
||||
The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentdhal_core._default_subscription import DefaultSubscription
|
||||
from agentdhal_core._default_topic import DefaultTopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
async def main():
|
||||
queue = asyncio.Queue[MyMessage]()
|
||||
|
||||
async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
|
||||
await queue.put(message)
|
||||
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await ClosureAgent.register_closure(
|
||||
runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
|
||||
)
|
||||
|
||||
runtime.start()
|
||||
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
result = await queue.get()
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
runtime (AgentRuntime): Runtime to register the agent to
|
||||
type (str): Agent type of registered agent
|
||||
closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages
|
||||
unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
|
||||
skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False.
|
||||
description (str, optional): Description of what agent does. Defaults to "".
|
||||
subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None.
|
||||
|
||||
Returns:
|
||||
AgentType: Type of the agent that was registered
|
||||
"""
|
||||
|
||||
def factory() -> ClosureAgent:
|
||||
return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)
|
||||
|
||||
assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions"
|
||||
agent_type = await cls.register(
|
||||
runtime=runtime,
|
||||
type=type,
|
||||
factory=factory, # type: ignore
|
||||
# There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s
|
||||
skip_class_subscriptions=True,
|
||||
skip_direct_message_subscription=skip_direct_message_subscription,
|
||||
)
|
||||
|
||||
subscriptions_list: List[Subscription] = []
|
||||
if subscriptions is not None:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list.extend(await subscriptions_list_result)
|
||||
else:
|
||||
# just ignore mypy here
|
||||
subscriptions_list.extend(subscriptions_list_result) # type: ignore
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
for message_type in handled_types:
|
||||
# TODO: support custom serializers
|
||||
serializer = try_get_known_serializers_for_type(message_type)
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
361
agent_dhal/agentdhal_core/_component_config.py
Normal file
361
agent_dhal/agentdhal_core/_component_config.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str
|
||||
ConfigT = TypeVar("ConfigT", bound=BaseModel)
|
||||
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
|
||||
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, covariant=True)
|
||||
|
||||
|
||||
class ComponentModel(BaseModel):
|
||||
"""Model class for a component. Contains all information required to instantiate a component."""
|
||||
|
||||
provider: str
|
||||
"""Describes how the component can be instantiated."""
|
||||
|
||||
component_type: ComponentType | None = None
|
||||
"""Logical type of the component. If missing, the component assumes the default type of the provider."""
|
||||
|
||||
version: int | None = None
|
||||
"""Version of the component specification. If missing, the component assumes whatever is the current version of the library used to load it. This is obviously dangerous and should be used for user authored ephmeral config. For all other configs version should be specified."""
|
||||
|
||||
component_version: int | None = None
|
||||
"""Version of the component. If missing, the component assumes the default version of the provider."""
|
||||
|
||||
description: str | None = None
|
||||
"""Description of the component."""
|
||||
|
||||
label: str | None = None
|
||||
"""Human readable label for the component. If missing the component assumes the class name of the provider."""
|
||||
|
||||
config: dict[str, Any]
|
||||
"""The schema validated config field is passed to a given class's implmentation of :py:meth:`agentdhal_core.ComponentConfigImpl._from_config` to create a new instance of the component class."""
|
||||
|
||||
|
||||
def _type_to_provider_str(t: type) -> str:
|
||||
return f"{t.__module__}.{t.__qualname__}"
|
||||
|
||||
|
||||
WELL_KNOWN_PROVIDERS = {
|
||||
"azure_openai_chat_completion_client": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient",
|
||||
"AzureOpenAIChatCompletionClient": "agentdhal_extensions.models.openai.AzureOpenAIChatCompletionClient",
|
||||
"openai_chat_completion_client": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient",
|
||||
"OpenAIChatCompletionClient": "agentdhal_extensions.models.openai.OpenAIChatCompletionClient",
|
||||
}
|
||||
|
||||
|
||||
class ComponentFromConfig(Generic[FromConfigT]):
|
||||
@classmethod
|
||||
def _from_config(cls, config: FromConfigT) -> Self:
|
||||
"""Create a new instance of the component from a configuration object.
|
||||
|
||||
Args:
|
||||
config (T): The configuration object.
|
||||
|
||||
Returns:
|
||||
Self: The new instance of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
@classmethod
|
||||
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
|
||||
"""Create a new instance of the component from a previous version of the configuration object.
|
||||
|
||||
This is only called when the version of the configuration object is less than the current version, since in this case the schema is not known.
|
||||
|
||||
Args:
|
||||
config (Dict[str, Any]): The configuration object.
|
||||
version (int): The version of the configuration object.
|
||||
|
||||
Returns:
|
||||
Self: The new instance of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support loading from past versions")
|
||||
|
||||
|
||||
class ComponentToConfig(Generic[ToConfigT]):
|
||||
"""The two methods a class must implement to be a component.
|
||||
|
||||
Args:
|
||||
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
|
||||
"""
|
||||
|
||||
component_type: ClassVar[ComponentType]
|
||||
"""The logical type of the component."""
|
||||
component_version: ClassVar[int] = 1
|
||||
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
|
||||
component_provider_override: ClassVar[str | None] = None
|
||||
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
|
||||
component_description: ClassVar[str | None] = None
|
||||
"""A description of the component. If not provided, the docstring of the class will be used."""
|
||||
component_label: ClassVar[str | None] = None
|
||||
"""A human readable label for the component. If not provided, the component class name will be used."""
|
||||
|
||||
def _to_config(self) -> ToConfigT:
|
||||
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
|
||||
|
||||
Returns:
|
||||
T: The configuration of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
def dump_component(self) -> ComponentModel:
|
||||
"""Dump the component to a model that can be loaded back in.
|
||||
|
||||
Raises:
|
||||
TypeError: If the component is a local class.
|
||||
|
||||
Returns:
|
||||
ComponentModel: The model representing the component.
|
||||
"""
|
||||
if self.component_provider_override is not None:
|
||||
provider = self.component_provider_override
|
||||
else:
|
||||
provider = _type_to_provider_str(self.__class__)
|
||||
# Warn if internal module name is used,
|
||||
if "._" in provider:
|
||||
warnings.warn(
|
||||
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if "<locals>" in provider:
|
||||
raise TypeError("Cannot dump component with local class")
|
||||
|
||||
if not hasattr(self, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
description = self.component_description
|
||||
if description is None and self.__class__.__doc__:
|
||||
# use docstring as description
|
||||
docstring = self.__class__.__doc__.strip()
|
||||
for marker in ["\n\nArgs:", "\n\nParameters:", "\n\nAttributes:", "\n\n"]:
|
||||
docstring = docstring.split(marker)[0]
|
||||
description = docstring.strip()
|
||||
|
||||
obj_config = self._to_config().model_dump(exclude_none=True)
|
||||
model = ComponentModel(
|
||||
provider=provider,
|
||||
component_type=self.component_type,
|
||||
version=self.component_version,
|
||||
component_version=self.component_version,
|
||||
description=description,
|
||||
label=self.component_label or self.__class__.__name__,
|
||||
config=obj_config,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
ExpectedType = TypeVar("ExpectedType")
|
||||
|
||||
|
||||
class ComponentLoader:
|
||||
@overload
|
||||
@classmethod
|
||||
def load_component(cls, model: ComponentModel | Dict[str, Any], expected: None = None) -> Self: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def load_component(cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType]) -> ExpectedType: ...
|
||||
|
||||
@classmethod
|
||||
def load_component(
|
||||
cls, model: ComponentModel | Dict[str, Any], expected: Type[ExpectedType] | None = None
|
||||
) -> Self | ExpectedType:
|
||||
"""Load a component from a model. Intended to be used with the return type of :py:meth:`agentdhal_core.ComponentConfig.dump_component`.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import ComponentModel
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
|
||||
component: ComponentModel = ... # type: ignore
|
||||
|
||||
model_client = ChatCompletionClient.load_component(component)
|
||||
|
||||
Args:
|
||||
model (ComponentModel): The model to load the component from.
|
||||
|
||||
Returns:
|
||||
Self: The loaded component.
|
||||
|
||||
Args:
|
||||
model (ComponentModel): _description_
|
||||
expected (Type[ExpectedType] | None, optional): Explicit type only if used directly on ComponentLoader. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider string is invalid.
|
||||
TypeError: Provider is not a subclass of ComponentConfigImpl, or the expected type does not match.
|
||||
|
||||
Returns:
|
||||
Self | ExpectedType: The loaded component.
|
||||
"""
|
||||
|
||||
# Use global and add further type checks
|
||||
|
||||
if isinstance(model, dict):
|
||||
loaded_model = ComponentModel(**model)
|
||||
else:
|
||||
loaded_model = model
|
||||
|
||||
# First, do a look up in well known providers
|
||||
if loaded_model.provider in WELL_KNOWN_PROVIDERS:
|
||||
loaded_model.provider = WELL_KNOWN_PROVIDERS[loaded_model.provider]
|
||||
|
||||
output = loaded_model.provider.rsplit(".", maxsplit=1)
|
||||
if len(output) != 2:
|
||||
raise ValueError("Invalid")
|
||||
|
||||
module_path, class_name = output
|
||||
module = importlib.import_module(module_path)
|
||||
component_class = module.__getattribute__(class_name)
|
||||
|
||||
if not is_component_class(component_class):
|
||||
raise TypeError("Invalid component class")
|
||||
|
||||
# We need to check the schema is valid
|
||||
if not hasattr(component_class, "component_config_schema"):
|
||||
raise AttributeError("component_config_schema not defined")
|
||||
|
||||
if not hasattr(component_class, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
loaded_config_version = loaded_model.component_version or component_class.component_version
|
||||
if loaded_config_version < component_class.component_version:
|
||||
try:
|
||||
instance = component_class._from_config_past_version(loaded_model.config, loaded_config_version) # type: ignore
|
||||
except NotImplementedError as e:
|
||||
raise NotImplementedError(
|
||||
f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented"
|
||||
) from e
|
||||
else:
|
||||
schema = component_class.component_config_schema # type: ignore
|
||||
validated_config = schema.model_validate(loaded_model.config)
|
||||
|
||||
# We're allowed to use the private method here
|
||||
instance = component_class._from_config(validated_config) # type: ignore
|
||||
|
||||
if expected is None and not isinstance(instance, cls):
|
||||
raise TypeError("Expected type does not match")
|
||||
elif expected is None:
|
||||
return cast(Self, instance)
|
||||
elif not isinstance(instance, expected):
|
||||
raise TypeError("Expected type does not match")
|
||||
else:
|
||||
return cast(ExpectedType, instance)
|
||||
|
||||
|
||||
class ComponentSchemaType(Generic[ConfigT]):
|
||||
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
|
||||
component_config_schema: Type[ConfigT]
|
||||
"""The Pydantic model class which represents the configuration of the component."""
|
||||
|
||||
required_class_vars = ["component_config_schema", "component_type"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
|
||||
# TODO: validate provider is loadable
|
||||
for var in cls.required_class_vars:
|
||||
if not hasattr(cls, var):
|
||||
warnings.warn(
|
||||
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...
|
||||
|
||||
|
||||
class Component(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
Generic[ConfigT],
|
||||
):
|
||||
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
|
||||
|
||||
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
|
||||
- :py:attr:`component_type` - What is the logical type of the component.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from agentdhal_core import Component
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class MyComponent(Component[Config]):
|
||||
component_type = "custom"
|
||||
component_config_schema = Config
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def _to_config(self) -> Config:
|
||||
return Config(value=self.value)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: Config) -> MyComponent:
|
||||
return cls(value=config.value)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if not is_component_class(cls):
|
||||
warnings.warn(
|
||||
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
# Should never be used directly, only for type checking
|
||||
class _ConcreteComponent(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
ComponentToConfig[ConfigT],
|
||||
ComponentLoader,
|
||||
Generic[ConfigT],
|
||||
): ...
|
||||
|
||||
|
||||
def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
|
||||
return (
|
||||
isinstance(cls, ComponentFromConfig)
|
||||
and isinstance(cls, ComponentToConfig)
|
||||
and isinstance(cls, ComponentSchemaType)
|
||||
and isinstance(cls, ComponentLoader)
|
||||
)
|
||||
|
||||
|
||||
def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
|
||||
return (
|
||||
issubclass(cls, ComponentFromConfig)
|
||||
and issubclass(cls, ComponentToConfig)
|
||||
and issubclass(cls, ComponentSchemaType)
|
||||
and issubclass(cls, ComponentLoader)
|
||||
)
|
||||
9
agent_dhal/agentdhal_core/_constants.py
Normal file
9
agent_dhal/agentdhal_core/_constants.py
Normal file
@@ -0,0 +1,9 @@
|
||||
ROOT_LOGGER_NAME = "agentdhal_core"
|
||||
"""str: Logger name used for root logger"""
|
||||
|
||||
EVENT_LOGGER_NAME = "agentdhal_core.events"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
|
||||
|
||||
TRACE_LOGGER_NAME = "agentdhal_core.trace"
|
||||
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
53
agent_dhal/agentdhal_core/_default_subscription.py
Normal file
53
agent_dhal/agentdhal_core/_default_subscription.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Callable, Type, TypeVar, overload
|
||||
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent, subscription_factory
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._type_subscription import TypeSubscription
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class DefaultSubscription(TypeSubscription):
|
||||
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents.
|
||||
|
||||
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context.
|
||||
|
||||
Args:
|
||||
topic_type (str, optional): The topic type to subscribe to. Defaults to "default".
|
||||
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type: str = "default", agent_type: str | AgentType | None = None):
|
||||
if agent_type is None:
|
||||
try:
|
||||
agent_type = SubscriptionInstantiationContext.agent_type().type
|
||||
except RuntimeError as e:
|
||||
raise CantHandleException(
|
||||
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register"
|
||||
) from e
|
||||
|
||||
super().__init__(topic_type, agent_type)
|
||||
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription() -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: ...
|
||||
|
||||
|
||||
def default_subscription(
|
||||
cls: Type[BaseAgentType] | None = None,
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]] | Type[BaseAgentType]:
|
||||
if cls is None:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])
|
||||
else:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])(cls)
|
||||
|
||||
|
||||
def type_subscription(topic_type: str) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
return subscription_factory(lambda: [DefaultSubscription(topic_type=topic_type)])
|
||||
23
agent_dhal/agentdhal_core/_default_topic.py
Normal file
23
agent_dhal/agentdhal_core/_default_topic.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class DefaultTopicId(TopicId):
|
||||
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId.
|
||||
|
||||
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default".
|
||||
|
||||
Args:
|
||||
type (str, optional): Topic type to publish message to. Defaults to "default".
|
||||
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, type: str = "default", source: str | None = None) -> None:
|
||||
if source is None:
|
||||
try:
|
||||
source = MessageHandlerContext.agent_id().key
|
||||
# If we aren't in the context of a message handler, we use the default source
|
||||
except RuntimeError:
|
||||
source = "default"
|
||||
|
||||
super().__init__(type, source)
|
||||
324
agent_dhal/agentdhal_core/_function_utils.py
Normal file
324
agent_dhal/agentdhal_core/_function_utils.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
|
||||
# Credit to original authors
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, create_model # type: ignore
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import Literal
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get the signature of a function with type annotations.
|
||||
|
||||
Args:
|
||||
call: The function to get the signature for
|
||||
|
||||
Returns:
|
||||
The signature of the function with type annotations
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
func_call = call.func if isinstance(call, partial) else call
|
||||
type_hints = typing.get_type_hints(func_call, globalns, include_extras=True)
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=type_hints[param.name],
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return_annotation = type_hints.get("return", inspect.Signature.empty)
|
||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
"""Get the return annotation of a function.
|
||||
|
||||
Args:
|
||||
call: The function to get the return annotation for
|
||||
|
||||
Returns:
|
||||
The return annotation of the function
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
if annotation is inspect.Signature.empty:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
type_hints = typing.get_type_hints(call, globalns, include_extras=True)
|
||||
return type_hints.get("return", inspect.Signature.empty)
|
||||
|
||||
|
||||
def get_param_annotations(
|
||||
typed_signature: inspect.Signature,
|
||||
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
|
||||
"""Get the type annotations of the parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
|
||||
Returns:
|
||||
A dictionary of the type annotations of the parameters of the function
|
||||
"""
|
||||
return {
|
||||
k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
|
||||
}
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
"""Parameters of a function as defined by the OpenAI API"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
properties: Dict[str, Dict[str, Any]]
|
||||
required: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""A function as defined by the OpenAI API"""
|
||||
|
||||
description: Annotated[str, Field(description="Description of the function")]
|
||||
name: Annotated[str, Field(description="Name of the function")]
|
||||
parameters: Annotated[Parameters, Field(description="Parameters of the function")]
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
|
||||
"""A function under tool as defined by the OpenAI API."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
function: Annotated[Function, Field(description="Function under tool")]
|
||||
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
|
||||
# handles Annotated
|
||||
if hasattr(v, "__metadata__"):
|
||||
retval = v.__metadata__[0]
|
||||
if isinstance(retval, str):
|
||||
return retval
|
||||
else:
|
||||
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
|
||||
else:
|
||||
return k
|
||||
|
||||
|
||||
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
k: The name of the parameter
|
||||
v: The type of the parameter
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydanitc model for the parameter
|
||||
"""
|
||||
|
||||
schema = TypeAdapter(v).json_schema()
|
||||
if k in default_values:
|
||||
dv = default_values[k]
|
||||
schema["default"] = dv
|
||||
|
||||
schema["description"] = type2description(k, v)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
|
||||
"""Get the required parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A list of the required parameters of the function
|
||||
"""
|
||||
return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]
|
||||
|
||||
|
||||
def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
|
||||
"""Get default values of parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A dictionary of the default values of the parameters of the function
|
||||
"""
|
||||
return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}
|
||||
|
||||
|
||||
def get_parameters(
|
||||
required: List[str],
|
||||
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
|
||||
default_values: Dict[str, Any],
|
||||
) -> Parameters:
|
||||
"""Get the parameters of a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
required: The required parameters of the function
|
||||
param_annotations: A dictionary of the type annotations of the parameters of the function
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydantic model for the parameters of the function
|
||||
"""
|
||||
return Parameters(
|
||||
properties={
|
||||
k: get_parameter_json_schema(k, v, default_values)
|
||||
for k, v in param_annotations.items()
|
||||
if v is not inspect.Signature.empty
|
||||
},
|
||||
required=required,
|
||||
)
|
||||
|
||||
|
||||
def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
|
||||
"""Get the missing annotations of a function
|
||||
|
||||
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
required: The required parameters of the function
|
||||
|
||||
Returns:
|
||||
A set of the missing annotations of the function
|
||||
"""
|
||||
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
|
||||
missing = all_missing.intersection(set(required))
|
||||
unannotated_with_default = all_missing.difference(missing)
|
||||
return missing, unannotated_with_default
|
||||
|
||||
|
||||
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
f: The function to get the JSON schema for
|
||||
name: The name of the function
|
||||
description: The description of the function
|
||||
|
||||
Returns:
|
||||
A JSON schema for the function
|
||||
|
||||
Raises:
|
||||
TypeError: If the function is not annotated
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def f(
|
||||
a: Annotated[str, "Parameter a"],
|
||||
b: int = 2,
|
||||
c: Annotated[float, "Parameter c"] = 0.1,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
get_function_schema(f, description="function f")
|
||||
|
||||
# {'type': 'function',
|
||||
# 'function': {'description': 'function f',
|
||||
# 'name': 'f',
|
||||
# 'parameters': {'type': 'object',
|
||||
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
|
||||
# 'b': {'type': 'int', 'description': 'b'},
|
||||
# 'c': {'type': 'float', 'description': 'Parameter c'}},
|
||||
# 'required': ['a']}}}
|
||||
|
||||
"""
|
||||
typed_signature = get_typed_signature(f)
|
||||
required = get_required_params(typed_signature)
|
||||
default_values = get_default_values(typed_signature)
|
||||
param_annotations = get_param_annotations(typed_signature)
|
||||
return_annotation = get_typed_return_annotation(f)
|
||||
missing, unannotated_with_default = get_missing_annotations(typed_signature, required)
|
||||
|
||||
if return_annotation is None:
|
||||
logger.warning(
|
||||
f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
|
||||
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
|
||||
)
|
||||
|
||||
if unannotated_with_default != set():
|
||||
unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
|
||||
logger.warning(
|
||||
f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
|
||||
+ f"{', '.join(unannotated_with_default_s)}."
|
||||
)
|
||||
|
||||
if missing != set():
|
||||
missing_s = [f"'{k}'" for k in sorted(missing)]
|
||||
raise TypeError(
|
||||
f"All parameters of the function '{f.__name__}' without default values must be annotated. "
|
||||
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
|
||||
)
|
||||
|
||||
fname = name if name else f.__name__
|
||||
|
||||
parameters = get_parameters(required, param_annotations, default_values=default_values)
|
||||
|
||||
function = ToolFunction(
|
||||
function=Function(
|
||||
description=description,
|
||||
name=fname,
|
||||
parameters=parameters,
|
||||
)
|
||||
)
|
||||
|
||||
return function.model_dump()
|
||||
|
||||
|
||||
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
|
||||
"""Normalize typing.Annotated types to the inner type."""
|
||||
if get_origin(type_hint) is Annotated:
|
||||
# Extract the inner type from Annotated
|
||||
return get_args(type_hint)[0] # type: ignore
|
||||
return type_hint
|
||||
|
||||
|
||||
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
|
||||
fields: Dict[str, tuple[Type[Any], Any]] = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
# This is handled externally
|
||||
if param_name == "cancellation_token":
|
||||
continue
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
raise ValueError("No annotation")
|
||||
|
||||
type = normalize_annotated_type(param.annotation)
|
||||
description = type2description(param_name, param.annotation)
|
||||
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
|
||||
|
||||
fields[param_name] = (type, Field(default=default_value, description=description))
|
||||
|
||||
return cast(BaseModel, create_model(name, **fields)) # type: ignore
|
||||
127
agent_dhal/agentdhal_core/_image.py
Normal file
127
agent_dhal/agentdhal_core/_image.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import GetCoreSchemaHandler, ValidationInfo
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class Image:
|
||||
"""Represents an image.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
Loading an image from a URL:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import Image
|
||||
from PIL import Image as PILImage
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
|
||||
async def from_url(url: str) -> Image:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
content = await response.read()
|
||||
return Image.from_pil(PILImage.open(content))
|
||||
|
||||
|
||||
image = asyncio.run(from_url("https://example.com/image"))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, image: PILImage.Image):
|
||||
self.image: PILImage.Image = image.convert("RGB")
|
||||
|
||||
@classmethod
|
||||
def from_pil(cls, pil_image: PILImage.Image) -> Image:
|
||||
return cls(pil_image)
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> Image:
|
||||
if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
|
||||
raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")
|
||||
|
||||
# A URI. Remove the prefix and decode the base64 string.
|
||||
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
|
||||
return cls.from_base64(base64_data)
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, base64_str: str) -> Image:
|
||||
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
|
||||
|
||||
def to_base64(self) -> str:
|
||||
buffered = BytesIO()
|
||||
self.image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file_path: Path) -> Image:
|
||||
return cls(PILImage.open(file_path))
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
# Show the image in Jupyter notebook
|
||||
return f'<img src="{self.data_uri}"/>'
|
||||
|
||||
@property
|
||||
def data_uri(self) -> str:
|
||||
return _convert_base64_to_data_uri(self.to_base64())
|
||||
|
||||
# Returns openai.types.chat.ChatCompletionContentPartImageParam, which is a TypedDict
|
||||
# We don't use the explicit type annotation so that we can avoid a dependency on the OpenAI Python SDK in this package.
|
||||
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> Dict[str, Any]:
|
||||
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
# Custom validation
|
||||
def validate(value: Any, validation_info: ValidationInfo) -> Image:
|
||||
if isinstance(value, dict):
|
||||
base_64 = cast(str | None, value.get("data")) # type: ignore
|
||||
if base_64 is None:
|
||||
raise ValueError("Expected 'data' key in the dictionary")
|
||||
return cls.from_base64(base_64)
|
||||
elif isinstance(value, cls):
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")
|
||||
|
||||
# Custom serialization
|
||||
def serialize(value: Image) -> dict[str, Any]:
|
||||
return {"data": value.to_base64()}
|
||||
|
||||
return core_schema.with_info_after_validator_function(
|
||||
validate,
|
||||
core_schema.any_schema(), # Accept any type; adjust if needed
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
|
||||
)
|
||||
|
||||
|
||||
def _convert_base64_to_data_uri(base64_image: str) -> str:
|
||||
def _get_mime_type_from_data_uri(base64_image: str) -> str:
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(base64_image)
|
||||
# Check the first few bytes for known signatures
|
||||
if image_data.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
|
||||
return "image/gif"
|
||||
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return "image/jpeg" # use jpeg for unknown formats, best guess.
|
||||
|
||||
mime_type = _get_mime_type_from_data_uri(base64_image)
|
||||
data_uri = f"data:{mime_type};base64,{base64_image}"
|
||||
return data_uri
|
||||
83
agent_dhal/agentdhal_core/_intervention.py
Normal file
83
agent_dhal/agentdhal_core/_intervention.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._message_context import MessageContext
|
||||
|
||||
__all__ = [
|
||||
"DropMessage",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
]
|
||||
|
||||
|
||||
@final
|
||||
class DropMessage:
|
||||
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class InterventionHandler(Protocol):
|
||||
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`agentdhal_core.base.AgentRuntime`.
|
||||
|
||||
The handler is called when the message is submitted to the runtime.
|
||||
|
||||
Currently the only runtime which supports this is the :class:`agentdhal_core.base.SingleThreadedAgentRuntime`.
|
||||
|
||||
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class MyInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
|
||||
if isinstance(message, MyMessage):
|
||||
message.content = message.content.upper()
|
||||
return message
|
||||
|
||||
|
||||
runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])
|
||||
|
||||
"""
|
||||
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
"""Called when a message is submitted to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.send_message`."""
|
||||
...
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
"""Called when a message is published to the AgentRuntime using :meth:`agentdhal_core.base.AgentRuntime.publish_message`."""
|
||||
...
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
|
||||
...
|
||||
|
||||
|
||||
class DefaultInterventionHandler(InterventionHandler):
|
||||
"""Simple class that provides a default implementation for all intervention
|
||||
handler methods, that simply returns the message unchanged. Allows for easy
|
||||
subclassing to override only the desired methods."""
|
||||
|
||||
async def on_send(
|
||||
self, message: Any, *, message_context: MessageContext, recipient: AgentId
|
||||
) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
14
agent_dhal/agentdhal_core/_message_context.py
Normal file
14
agent_dhal/agentdhal_core/_message_context.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageContext:
|
||||
sender: AgentId | None
|
||||
topic_id: TopicId | None
|
||||
is_rpc: bool
|
||||
cancellation_token: CancellationToken
|
||||
message_id: str
|
||||
31
agent_dhal/agentdhal_core/_message_handler_context.py
Normal file
31
agent_dhal/agentdhal_core/_message_handler_context.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
|
||||
|
||||
class MessageHandlerContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling."
|
||||
)
|
||||
|
||||
_MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
MessageHandlerContext._MESSAGE_HANDLER_CONTEXT.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls._MESSAGE_HANDLER_CONTEXT.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e
|
||||
264
agent_dhal/agentdhal_core/_queue.py
Normal file
264
agent_dhal/agentdhal_core/_queue.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# Copy of Asyncio queue: https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py
|
||||
# So that shutdown can be used in <3.13
|
||||
# Modified to work outside of the asyncio package
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import threading
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
class _LoopBoundMixin:
|
||||
_loop = None
|
||||
|
||||
def _get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if self._loop is None:
|
||||
with _global_lock:
|
||||
if self._loop is None:
|
||||
self._loop = loop
|
||||
if loop is not self._loop:
|
||||
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||
return loop
|
||||
|
||||
|
||||
class QueueShutDown(Exception):
|
||||
"""Raised when putting on to or getting from a shut-down Queue."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Queue(_LoopBoundMixin, Generic[T]):
|
||||
def __init__(self, maxsize: int = 0):
|
||||
self._maxsize = maxsize
|
||||
self._getters = collections.deque[asyncio.Future[None]]()
|
||||
self._putters = collections.deque[asyncio.Future[None]]()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = asyncio.Event()
|
||||
self._finished.set()
|
||||
self._queue = collections.deque[T]()
|
||||
self._is_shutdown = False
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
|
||||
def _get(self) -> T:
|
||||
return self._queue.popleft()
|
||||
|
||||
def _put(self, item: T) -> None:
|
||||
self._queue.append(item)
|
||||
|
||||
# End of the overridable methods.
|
||||
|
||||
def _wakeup_next(self, waiters: collections.deque[asyncio.Future[None]]) -> None:
|
||||
# Wake up the next waiter (if any) that isn't cancelled.
|
||||
while waiters:
|
||||
waiter = waiters.popleft()
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
break
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{type(self).__name__} at {id(self):#x} {self._format()}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<{type(self).__name__} {self._format()}>"
|
||||
|
||||
def _format(self) -> str:
|
||||
result = f"maxsize={self._maxsize!r}"
|
||||
if getattr(self, "_queue", None):
|
||||
result += f" _queue={list(self._queue)!r}"
|
||||
if self._getters:
|
||||
result += f" _getters[{len(self._getters)}]"
|
||||
if self._putters:
|
||||
result += f" _putters[{len(self._putters)}]"
|
||||
if self._unfinished_tasks:
|
||||
result += f" tasks={self._unfinished_tasks}"
|
||||
if self._is_shutdown:
|
||||
result += " shutdown"
|
||||
return result
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def maxsize(self) -> int:
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Return True if the queue is empty, False otherwise."""
|
||||
return not self._queue
|
||||
|
||||
def full(self) -> bool:
|
||||
"""Return True if there are maxsize items in the queue.
|
||||
|
||||
Note: if the Queue was initialized with maxsize=0 (the default),
|
||||
then full() is never True.
|
||||
"""
|
||||
if self._maxsize <= 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self._maxsize
|
||||
|
||||
async def put(self, item: T) -> None:
|
||||
"""Put an item into the queue.
|
||||
|
||||
Put an item into the queue. If the queue is full, wait until a free
|
||||
slot is available before adding item.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
while self.full():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
putter = self._get_loop().create_future()
|
||||
self._putters.append(putter)
|
||||
try:
|
||||
await putter
|
||||
except:
|
||||
putter.cancel() # Just in case putter is not done yet.
|
||||
try:
|
||||
# Clean self._putters from canceled putters.
|
||||
self._putters.remove(putter)
|
||||
except ValueError:
|
||||
# The putter could be removed from self._putters by a
|
||||
# previous get_nowait call or a shutdown call.
|
||||
pass
|
||||
if not self.full() and not putter.cancelled():
|
||||
# We were woken up by get_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._putters)
|
||||
raise
|
||||
return self.put_nowait(item)
|
||||
|
||||
def put_nowait(self, item: T) -> None:
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise QueueFull.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
if self.full():
|
||||
raise asyncio.QueueFull
|
||||
self._put(item)
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._wakeup_next(self._getters)
|
||||
|
||||
async def get(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If queue is empty, wait until an item is available.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
while self.empty():
|
||||
if self._is_shutdown and self.empty():
|
||||
raise QueueShutDown
|
||||
getter = self._get_loop().create_future()
|
||||
self._getters.append(getter)
|
||||
try:
|
||||
await getter
|
||||
except:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call, or a shutdown call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._getters)
|
||||
raise
|
||||
return self.get_nowait()
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Return an item if one is immediately available, else raise QueueEmpty.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
if self.empty():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
raise asyncio.QueueEmpty
|
||||
item = self._get()
|
||||
self._wakeup_next(self._putters)
|
||||
return item
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each get() used to fetch a task,
|
||||
a subsequent call to task_done() tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a join() is currently blocking, it will resume when all items have
|
||||
been processed (meaning that a task_done() call was received for every
|
||||
item that had been put() into the queue).
|
||||
|
||||
shutdown(immediate=True) calls task_done() for each remaining item in
|
||||
the queue.
|
||||
|
||||
Raises ValueError if called more times than there were items placed in
|
||||
the queue.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError("task_done() called too many times")
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
async def join(self) -> None:
|
||||
"""Block until all items in the queue have been gotten and processed.
|
||||
|
||||
The count of unfinished tasks goes up whenever an item is added to the
|
||||
queue. The count goes down whenever a consumer calls task_done() to
|
||||
indicate that the item was retrieved and all work on it is complete.
|
||||
When the count of unfinished tasks drops to zero, join() unblocks.
|
||||
"""
|
||||
if self._unfinished_tasks > 0:
|
||||
await self._finished.wait()
|
||||
|
||||
def shutdown(self, immediate: bool = False) -> None:
|
||||
"""Shut-down the queue, making queue gets and puts raise QueueShutDown.
|
||||
|
||||
By default, gets will only raise once the queue is empty. Set
|
||||
'immediate' to True to make gets raise immediately instead.
|
||||
|
||||
All blocked callers of put() and get() will be unblocked. If
|
||||
'immediate', a task is marked as done for each item remaining in
|
||||
the queue, which may unblock callers of join().
|
||||
"""
|
||||
self._is_shutdown = True
|
||||
if immediate:
|
||||
while not self.empty():
|
||||
self._get()
|
||||
if self._unfinished_tasks > 0:
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
# All getters need to re-check queue-empty to raise ShutDown
|
||||
while self._getters:
|
||||
getter = self._getters.popleft()
|
||||
if not getter.done():
|
||||
getter.set_result(None)
|
||||
while self._putters:
|
||||
putter = self._putters.popleft()
|
||||
if not putter.done():
|
||||
putter.set_result(None)
|
||||
518
agent_dhal/agentdhal_core/_routed_agent.py
Normal file
518
agent_dhal/agentdhal_core/_routed_agent.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
DefaultDict,
|
||||
List,
|
||||
Literal,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_type_hints,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._base_agent import BaseAgent
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._type_helpers import AnyType, get_types
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
logger = logging.getLogger("agentdhal_core")
|
||||
|
||||
AgentT = TypeVar("AgentT")
|
||||
ReceivesT = TypeVar("ReceivesT")
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
# TODO: Generic typevar bound binding U to agent type
|
||||
# Can't do because python doesnt support it
|
||||
|
||||
|
||||
# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
|
||||
# Revisit this later to see if we can remove the ignore.
|
||||
@runtime_checkable
|
||||
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore
|
||||
target_types: Sequence[type]
|
||||
produces_types: Sequence[type]
|
||||
is_message_handler: Literal[True]
|
||||
router: Callable[[ReceivesT, MessageContext], bool]
|
||||
|
||||
# agent_instance binds to self in the method
|
||||
@staticmethod
|
||||
async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocol for the outer function to check checked arg names
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def message_handler(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for generic message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle both event and RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
wrapper_handler.router = match or (lambda _message, _ctx: True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
def event(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, None]
|
||||
):
|
||||
"""Decorator for event message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle event messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The event message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must return `None`.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found. Please provide a type hint for the message parameter.")
|
||||
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx) # type: ignore
|
||||
|
||||
if return_value is not None:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} is not None.")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.")
|
||||
|
||||
return None
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
# Wrap the match function with a check on the is_rpc flag.
|
||||
wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def rpc(
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for RPC message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`agentdhal_core.MessageContext` object.
|
||||
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return parameter not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
class RoutedAgent(BaseAgent):
|
||||
"""A base class for agents that route messages to handlers based on the type of the message
|
||||
and optional matching functions.
|
||||
|
||||
To create a routed agent, subclass this class and add message handlers as methods decorated with
|
||||
either :func:`event` or :func:`rpc` decorator.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dataclasses import dataclass
|
||||
from agentdhal_core import MessageContext
|
||||
from agentdhal_core import RoutedAgent, event, rpc
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageWithContent:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
pass
|
||||
|
||||
|
||||
class MyAgent(RoutedAgent):
|
||||
def __init__(self):
|
||||
super().__init__("MyAgent")
|
||||
|
||||
@event
|
||||
async def handle_event_message(self, message: Message, ctx: MessageContext) -> None:
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(MessageWithContent("event handled"), ctx.topic_id)
|
||||
|
||||
@rpc(match=lambda message, ctx: message.content == "special") # type: ignore
|
||||
async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
|
||||
return Response()
|
||||
"""
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: DefaultDict[
|
||||
Type[Any],
|
||||
List[MessageHandler[RoutedAgent, Any, Any]],
|
||||
] = DefaultDict(list)
|
||||
|
||||
handlers = self._discover_handlers()
|
||||
for message_handler in handlers:
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type].append(message_handler)
|
||||
|
||||
super().__init__(description)
|
||||
|
||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
"""Handle a message by routing it to the appropriate message handler.
|
||||
Do not override this method in subclasses. Instead, add message handlers as methods decorated with
|
||||
either the :func:`event` or :func:`rpc` decorator."""
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
handlers = self._handlers.get(key_type) # type: ignore
|
||||
if handlers is not None:
|
||||
# Iterate over all handlers for this matching message type.
|
||||
# Call the first handler whose router returns True and then return the result.
|
||||
for h in handlers:
|
||||
if h.router(message, ctx):
|
||||
return await h(self, message, ctx)
|
||||
return await self.on_unhandled_message(message, ctx) # type: ignore
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
"""Called when a message is received that does not have a matching message handler.
|
||||
The default implementation logs an info message."""
|
||||
logger.info(f"Unhandled message: {message}")
|
||||
|
||||
@classmethod
|
||||
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
|
||||
handlers: List[MessageHandler[Any, Any, Any]] = []
|
||||
for attr in dir(cls):
|
||||
if callable(getattr(cls, attr, None)):
|
||||
# Since we are getting it from the class, self is not bound
|
||||
handler = getattr(cls, attr)
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
handlers.append(cast(MessageHandler[Any, Any, Any], handler))
|
||||
return handlers
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
# TODO handle deduplication
|
||||
handlers = cls._discover_handlers()
|
||||
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
|
||||
types.extend(cls.internal_extra_handles_types)
|
||||
for handler in handlers:
|
||||
for t in handler.target_types:
|
||||
# TODO: support different serializers
|
||||
serializers = try_get_known_serializers_for_type(t)
|
||||
if len(serializers) == 0:
|
||||
raise ValueError(f"No serializers found for type {t}.")
|
||||
|
||||
types.append((t, try_get_known_serializers_for_type(t)))
|
||||
return types
|
||||
78
agent_dhal/agentdhal_core/_runtime_impl_helpers.py
Normal file
78
agent_dhal/agentdhal_core/_runtime_impl_helpers.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from collections import defaultdict
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Sequence, Set
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
async def get_impl(
|
||||
*,
|
||||
id_or_type: AgentId | AgentType | str,
|
||||
key: str,
|
||||
lazy: bool,
|
||||
instance_getter: Callable[[AgentId], Awaitable[Agent]],
|
||||
) -> AgentId:
|
||||
if isinstance(id_or_type, AgentId):
|
||||
if not lazy:
|
||||
await instance_getter(id_or_type)
|
||||
|
||||
return id_or_type
|
||||
|
||||
type_str = id_or_type if isinstance(id_or_type, str) else id_or_type.type
|
||||
id = AgentId(type_str, key)
|
||||
if not lazy:
|
||||
await instance_getter(id)
|
||||
|
||||
return id
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
def __init__(self) -> None:
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Subscription]:
|
||||
return self._subscriptions
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub == subscription for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
# Check if the subscription exists
|
||||
if not any(sub.id == id for sub in self._subscriptions):
|
||||
raise ValueError("Subscription does not exist")
|
||||
|
||||
def is_not_sub(x: Subscription) -> bool:
|
||||
return x.id != id
|
||||
|
||||
self._subscriptions = list(filter(is_not_sub, self._subscriptions))
|
||||
|
||||
# Rebuild the subscriptions
|
||||
self._rebuild_subscriptions(self._seen_topics)
|
||||
|
||||
async def get_subscribed_recipients(self, topic: TopicId) -> List[AgentId]:
|
||||
if topic not in self._seen_topics:
|
||||
self._build_for_new_topic(topic)
|
||||
return self._subscribed_recipients[topic]
|
||||
|
||||
# TODO: optimize this...
|
||||
def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None:
|
||||
self._subscribed_recipients.clear()
|
||||
for topic in topics:
|
||||
self._build_for_new_topic(topic)
|
||||
|
||||
def _build_for_new_topic(self, topic: TopicId) -> None:
|
||||
self._seen_topics.add(topic)
|
||||
for subscription in self._subscriptions:
|
||||
if subscription.is_match(topic):
|
||||
self._subscribed_recipients[topic].append(subscription.map_to_agent(topic))
|
||||
258
agent_dhal/agentdhal_core/_serialization.py
Normal file
258
agent_dhal/agentdhal_core/_serialization.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
|
||||
from google.protobuf import any_pb2
|
||||
from google.protobuf.message import Message
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._type_helpers import is_union
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MessageSerializer(Protocol[T]):
|
||||
@property
|
||||
def data_content_type(self) -> str: ...
|
||||
|
||||
@property
|
||||
def type_name(self) -> str: ...
|
||||
|
||||
def deserialize(self, payload: bytes) -> T: ...
|
||||
|
||||
def serialize(self, message: T) -> bytes: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsDataclass(Protocol):
|
||||
# as already noted in comments, checking for this attribute is currently
|
||||
# the most reliable way to ascertain that something is a dataclass
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
|
||||
|
||||
def is_dataclass(cls: type[Any]) -> bool:
|
||||
return hasattr(cls, "__dataclass_fields__")
|
||||
|
||||
|
||||
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
|
||||
# iterate fields and check if any of them are dataclasses
|
||||
return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def contains_a_union(cls: type[IsDataclass]) -> bool:
|
||||
return any(is_union(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
|
||||
for f in fields(cls):
|
||||
field_type = f.type
|
||||
# Resolve forward references and other annotations
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
# If the field type is directly a subclass of BaseModel
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return True
|
||||
|
||||
# If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc.
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
# Recursively check the argument types
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
# Handle nested generics like List[List[BaseModel]]
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
# Handle Union types
|
||||
elif args:
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_nested_base_model_in_type(tp: Any) -> bool:
|
||||
"""Helper function to check if a type or its arguments is a BaseModel subclass."""
|
||||
origin = get_origin(tp)
|
||||
args = get_args(tp)
|
||||
|
||||
if isinstance(tp, type) and issubclass(tp, BaseModel):
|
||||
return True
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
|
||||
JSON_DATA_CONTENT_TYPE = "application/json"
|
||||
"""JSON data content type"""
|
||||
|
||||
# TODO: what's the correct content type? There seems to be some disagreement over what it should be
|
||||
PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf"
|
||||
"""Protobuf data content type"""
|
||||
|
||||
|
||||
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
|
||||
def __init__(self, cls: type[DataclassT]) -> None:
|
||||
if contains_a_union(cls):
|
||||
raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")
|
||||
|
||||
if has_nested_dataclass(cls) or has_nested_base_model(cls):
|
||||
raise ValueError(
|
||||
"Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
|
||||
)
|
||||
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> DataclassT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls(**json.loads(message_str))
|
||||
|
||||
def serialize(self, message: DataclassT) -> bytes:
|
||||
return json.dumps(asdict(message)).encode("utf-8")
|
||||
|
||||
|
||||
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
||||
|
||||
|
||||
class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
|
||||
def __init__(self, cls: type[PydanticT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> PydanticT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls.model_validate_json(message_str)
|
||||
|
||||
def serialize(self, message: PydanticT) -> bytes:
|
||||
return message.model_dump_json().encode("utf-8")
|
||||
|
||||
|
||||
ProtobufT = TypeVar("ProtobufT", bound=Message)
|
||||
|
||||
|
||||
# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
|
||||
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
|
||||
def __init__(self, cls: type[ProtobufT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return PROTOBUF_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> ProtobufT:
|
||||
# Parse payload into a proto any
|
||||
any_proto = any_pb2.Any()
|
||||
any_proto.ParseFromString(payload)
|
||||
|
||||
destination_message = self.cls()
|
||||
|
||||
if not any_proto.Unpack(destination_message): # type: ignore
|
||||
raise ValueError(f"Failed to unpack payload into {self.cls}")
|
||||
|
||||
return destination_message
|
||||
|
||||
def serialize(self, message: ProtobufT) -> bytes:
|
||||
any_proto = any_pb2.Any()
|
||||
any_proto.Pack(message) # type: ignore
|
||||
return any_proto.SerializeToString()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnknownPayload:
|
||||
type_name: str
|
||||
data_content_type: str
|
||||
payload: bytes
|
||||
|
||||
|
||||
def _type_name(cls: type[Any] | Any) -> str:
|
||||
# If cls is a protobuf, then we need to determine the descriptor
|
||||
if isinstance(cls, type):
|
||||
if issubclass(cls, Message):
|
||||
return cast(str, cls.DESCRIPTOR.full_name)
|
||||
elif isinstance(cls, Message):
|
||||
return cast(str, cls.DESCRIPTOR.full_name)
|
||||
|
||||
if isinstance(cls, type):
|
||||
return cls.__name__
|
||||
else:
|
||||
return cast(str, cls.__class__.__name__)
|
||||
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer[Any]]:
|
||||
""":meta private:"""
|
||||
|
||||
serializers: List[MessageSerializer[Any]] = []
|
||||
if issubclass(cls, BaseModel):
|
||||
serializers.append(PydanticJsonMessageSerializer(cls))
|
||||
elif is_dataclass(cls):
|
||||
serializers.append(DataclassJsonMessageSerializer(cls))
|
||||
elif issubclass(cls, Message):
|
||||
serializers.append(ProtobufMessageSerializer(cls))
|
||||
|
||||
return serializers
|
||||
|
||||
|
||||
class SerializationRegistry:
|
||||
""":meta private:"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# type_name, data_content_type -> serializer
|
||||
self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {}
|
||||
|
||||
def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
if isinstance(serializer, Sequence):
|
||||
for c in serializer:
|
||||
self.add_serializer(c)
|
||||
return
|
||||
|
||||
self._serializers[(serializer.type_name, serializer.data_content_type)] = serializer
|
||||
|
||||
def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any:
|
||||
serializer = self._serializers.get((type_name, data_content_type))
|
||||
if serializer is None:
|
||||
return UnknownPayload(type_name, data_content_type, payload)
|
||||
|
||||
return serializer.deserialize(payload)
|
||||
|
||||
def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes:
|
||||
serializer = self._serializers.get((type_name, data_content_type))
|
||||
if serializer is None:
|
||||
raise ValueError(f"Unknown type {type_name} with content type {data_content_type}")
|
||||
|
||||
return serializer.serialize(message)
|
||||
|
||||
def is_registered(self, type_name: str, data_content_type: str) -> bool:
|
||||
return (type_name, data_content_type) in self._serializers
|
||||
|
||||
def type_name(self, message: Any) -> str:
|
||||
return _type_name(message)
|
||||
1029
agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py
Normal file
1029
agent_dhal/agentdhal_core/_single_threaded_agent_runtime.py
Normal file
File diff suppressed because it is too large
Load Diff
65
agent_dhal/agentdhal_core/_subscription.py
Normal file
65
agent_dhal/agentdhal_core/_subscription.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, Callable, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Subscription(Protocol):
|
||||
"""Subscriptions define the topics that an agent is interested in."""
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the ID of the subscription.
|
||||
|
||||
Implementations should return a unique ID for the subscription. Usually this is a UUID.
|
||||
|
||||
Returns:
|
||||
str: ID of the subscription.
|
||||
"""
|
||||
...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check if two subscriptions are equal.
|
||||
|
||||
Args:
|
||||
other (object): Other subscription to compare against.
|
||||
|
||||
Returns:
|
||||
bool: True if the subscriptions are equal, False otherwise.
|
||||
"""
|
||||
if not isinstance(other, Subscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
"""Check if a given topic_id matches the subscription.
|
||||
|
||||
Args:
|
||||
topic_id (TopicId): TopicId to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the topic_id matches the subscription, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
"""Map a topic_id to an agent. Should only be called if `is_match` returns True for the given topic_id.
|
||||
|
||||
Args:
|
||||
topic_id (TopicId): TopicId to map.
|
||||
|
||||
Returns:
|
||||
AgentId: ID of the agent that should handle the topic_id.
|
||||
|
||||
Raises:
|
||||
CantHandleException: If the subscription cannot handle the topic_id.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Helper alias to represent the lambdas used to define subscriptions
|
||||
UnboundSubscription = Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
33
agent_dhal/agentdhal_core/_subscription_context.py
Normal file
33
agent_dhal/agentdhal_core/_subscription_context.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_type import AgentType
|
||||
|
||||
|
||||
class SubscriptionInstantiationContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext cannot be instantiated. It is a static class that provides context management for subscription instantiation."
|
||||
)
|
||||
|
||||
_SUBSCRIPTION_CONTEXT_VAR: ClassVar[ContextVar[AgentType]] = ContextVar("_SUBSCRIPTION_CONTEXT_VAR")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentType) -> Generator[None, Any, None]:
|
||||
""":meta private:"""
|
||||
token = SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
SubscriptionInstantiationContext._SUBSCRIPTION_CONTEXT_VAR.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_type(cls) -> AgentType:
|
||||
try:
|
||||
return cls._SUBSCRIPTION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"SubscriptionInstantiationContext.runtime() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
|
||||
) from e
|
||||
25
agent_dhal/agentdhal_core/_telemetry/__init__.py
Normal file
25
agent_dhal/agentdhal_core/_telemetry/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from ._genai import (
|
||||
trace_create_agent_span,
|
||||
trace_invoke_agent_span,
|
||||
trace_tool_span,
|
||||
)
|
||||
from ._propagation import (
|
||||
EnvelopeMetadata,
|
||||
TelemetryMetadataContainer,
|
||||
get_telemetry_envelope_metadata,
|
||||
get_telemetry_grpc_metadata,
|
||||
)
|
||||
from ._tracing import TraceHelper
|
||||
from ._tracing_config import MessageRuntimeTracingConfig
|
||||
|
||||
__all__ = [
|
||||
"EnvelopeMetadata",
|
||||
"get_telemetry_envelope_metadata",
|
||||
"get_telemetry_grpc_metadata",
|
||||
"TelemetryMetadataContainer",
|
||||
"TraceHelper",
|
||||
"MessageRuntimeTracingConfig",
|
||||
"trace_create_agent_span",
|
||||
"trace_invoke_agent_span",
|
||||
"trace_tool_span",
|
||||
]
|
||||
1
agent_dhal/agentdhal_core/_telemetry/_constants.py
Normal file
1
agent_dhal/agentdhal_core/_telemetry/_constants.py
Normal file
@@ -0,0 +1 @@
|
||||
NAMESPACE = "agentdhal"
|
||||
214
agent_dhal/agentdhal_core/_telemetry/_genai.py
Normal file
214
agent_dhal/agentdhal_core/_telemetry/_genai.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Span, SpanKind
|
||||
|
||||
from .._agent_instantiation import AgentInstantiationContext
|
||||
|
||||
# OpenTelemetry semantic convention constants for GenAI operations
|
||||
# Copied from opentelemetry-semantic-conventions to avoid dependency
|
||||
|
||||
# GenAI Agent attributes
|
||||
GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description"
|
||||
GEN_AI_AGENT_ID = "gen_ai.agent.id"
|
||||
GEN_AI_AGENT_NAME = "gen_ai.agent.name"
|
||||
|
||||
# GenAI Operation attributes
|
||||
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
# GenAI Tool attributes
|
||||
GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id"
|
||||
GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description"
|
||||
GEN_AI_TOOL_NAME = "gen_ai.tool.name"
|
||||
|
||||
# Error attributes
|
||||
ERROR_TYPE = "error.type"
|
||||
|
||||
|
||||
class GenAiOperationNameValues(Enum):
|
||||
"""Enum for GenAI operation name values."""
|
||||
|
||||
CHAT = "chat"
|
||||
CREATE_AGENT = "create_agent"
|
||||
EMBEDDINGS = "embeddings"
|
||||
EXECUTE_TOOL = "execute_tool"
|
||||
GENERATE_CONTENT = "generate_content"
|
||||
INVOKE_AGENT = "invoke_agent"
|
||||
TEXT_COMPLETION = "text_completion"
|
||||
|
||||
|
||||
# Constant for system name
|
||||
GENAI_SYSTEM_AUTOGEN = "agentdhal"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_tool_span(
|
||||
tool_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
tool_description: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for tool execution following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
|
||||
Args:
|
||||
tool_name (str): The name of the tool being executed.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
tool_description (Optional[str]): A description of the tool.
|
||||
tool_call_id (Optional[str]): A unique identifier for the tool call.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.EXECUTE_TOOL.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_TOOL_NAME: tool_name,
|
||||
}
|
||||
if tool_description is not None:
|
||||
span_attributes[GEN_AI_TOOL_DESCRIPTION] = tool_description
|
||||
if tool_call_id is not None:
|
||||
span_attributes[GEN_AI_TOOL_CALL_ID] = tool_call_id
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.EXECUTE_TOOL.value} {tool_name}",
|
||||
kind=SpanKind.INTERNAL,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_create_agent_span(
|
||||
agent_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_description: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for agent creation following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
Args:
|
||||
agent_name (str): The name of the agent being created.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
agent_id (Optional[str]): The unique identifier for the agent.
|
||||
agent_description (Optional[str]): A description of the agent.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.CREATE_AGENT.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_AGENT_NAME: agent_name,
|
||||
}
|
||||
if agent_id is None:
|
||||
# Try to see if we can get the agent ID from the current context
|
||||
try:
|
||||
agent_id = str(AgentInstantiationContext.current_agent_id())
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
if agent_id is not None:
|
||||
span_attributes[GEN_AI_AGENT_ID] = agent_id
|
||||
if agent_description is not None:
|
||||
span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.CREATE_AGENT.value} {agent_name}",
|
||||
kind=SpanKind.CLIENT,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_invoke_agent_span(
|
||||
agent_name: str,
|
||||
*,
|
||||
tracer: Optional[trace.Tracer] = None,
|
||||
parent: Optional[Span] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_description: Optional[str] = None,
|
||||
) -> Generator[Span, Any, None]:
|
||||
"""Context manager to create a span for invoking an agent following the
|
||||
OpenTelemetry Semantic conventions for generative AI systems.
|
||||
|
||||
See the GenAI semantic conventions documentation:
|
||||
`OpenTelemetry GenAI Semantic Conventions <https://opentelemetry.io/docs/specs/semconv/gen-ai/>`__
|
||||
|
||||
.. warning::
|
||||
|
||||
The GenAI Semantic Conventions are still in incubation and
|
||||
subject to changes in future releases.
|
||||
|
||||
Args:
|
||||
agent_name (str): The name of the agent being invoked.
|
||||
tracer (Optional[trace.Tracer]): The tracer to use for creating the span.
|
||||
parent (Optional[Span]): The parent span to link this span to.
|
||||
agent_id (Optional[str]): The unique identifier for the agent.
|
||||
agent_description (Optional[str]): A description of the agent.
|
||||
"""
|
||||
if tracer is None:
|
||||
tracer = trace.get_tracer("agentdhal-core")
|
||||
span_attributes = {
|
||||
GEN_AI_OPERATION_NAME: GenAiOperationNameValues.INVOKE_AGENT.value,
|
||||
GEN_AI_SYSTEM: GENAI_SYSTEM_AUTOGEN,
|
||||
GEN_AI_AGENT_NAME: agent_name,
|
||||
}
|
||||
if agent_id is not None:
|
||||
span_attributes[GEN_AI_AGENT_ID] = agent_id
|
||||
if agent_description is not None:
|
||||
span_attributes[GEN_AI_AGENT_DESCRIPTION] = agent_description
|
||||
with tracer.start_as_current_span(
|
||||
f"{GenAiOperationNameValues.INVOKE_AGENT.value} {agent_name}",
|
||||
kind=SpanKind.CLIENT,
|
||||
context=trace.set_span_in_context(parent) if parent else None,
|
||||
attributes=span_attributes,
|
||||
) as span:
|
||||
try:
|
||||
yield span
|
||||
except Exception as e:
|
||||
# Set the exception details on the span if an error occurs
|
||||
span.record_exception(e)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
|
||||
span.set_attribute(ERROR_TYPE, type(e).__name__)
|
||||
raise
|
||||
127
agent_dhal/agentdhal_core/_telemetry/_propagation.py
Normal file
127
agent_dhal/agentdhal_core/_telemetry/_propagation.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Mapping, Optional, Sequence
|
||||
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.propagate import extract
|
||||
from opentelemetry.trace import Link, get_current_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class EnvelopeMetadata:
|
||||
"""Metadata for an envelope."""
|
||||
|
||||
traceparent: Optional[str] = None
|
||||
tracestate: Optional[str] = None
|
||||
links: Optional[Sequence[Link]] = None
|
||||
|
||||
|
||||
def _get_carrier_for_envelope_metadata(envelope_metadata: EnvelopeMetadata) -> Dict[str, str]:
|
||||
carrier: Dict[str, str] = {}
|
||||
if envelope_metadata.traceparent is not None:
|
||||
carrier["traceparent"] = envelope_metadata.traceparent
|
||||
if envelope_metadata.tracestate is not None:
|
||||
carrier["tracestate"] = envelope_metadata.tracestate
|
||||
return carrier
|
||||
|
||||
|
||||
def get_telemetry_envelope_metadata() -> EnvelopeMetadata:
|
||||
"""
|
||||
Retrieves the telemetry envelope metadata.
|
||||
|
||||
Returns:
|
||||
EnvelopeMetadata: The envelope metadata containing the traceparent and tracestate.
|
||||
"""
|
||||
carrier: Dict[str, str] = {}
|
||||
TraceContextTextMapPropagator().inject(carrier)
|
||||
return EnvelopeMetadata(
|
||||
traceparent=carrier.get("traceparent"),
|
||||
tracestate=carrier.get("tracestate"),
|
||||
)
|
||||
|
||||
|
||||
def _get_carrier_for_remote_call_metadata(remote_call_metadata: Mapping[str, str]) -> Dict[str, str]:
|
||||
carrier: Dict[str, str] = {}
|
||||
traceparent = remote_call_metadata.get("traceparent")
|
||||
tracestate = remote_call_metadata.get("tracestate")
|
||||
if traceparent:
|
||||
carrier["traceparent"] = traceparent
|
||||
if tracestate:
|
||||
carrier["tracestate"] = tracestate
|
||||
return carrier
|
||||
|
||||
|
||||
def get_telemetry_grpc_metadata(existingMetadata: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Retrieves the telemetry gRPC metadata.
|
||||
|
||||
Args:
|
||||
existingMetadata (Optional[Mapping[str, str]]): The existing metadata to include in the gRPC metadata.
|
||||
|
||||
Returns:
|
||||
Mapping[str, str]: The gRPC metadata containing the traceparent and tracestate.
|
||||
"""
|
||||
carrier: Dict[str, str] = {}
|
||||
TraceContextTextMapPropagator().inject(carrier)
|
||||
traceparent = carrier.get("traceparent")
|
||||
tracestate = carrier.get("tracestate")
|
||||
metadata: Dict[str, str] = {}
|
||||
if existingMetadata is not None:
|
||||
for key, value in existingMetadata.items():
|
||||
metadata[key] = value
|
||||
if traceparent is not None:
|
||||
metadata["traceparent"] = traceparent
|
||||
if tracestate is not None:
|
||||
metadata["tracestate"] = tracestate
|
||||
return metadata
|
||||
|
||||
|
||||
TelemetryMetadataContainer = Optional[EnvelopeMetadata] | Mapping[str, str]
|
||||
|
||||
|
||||
def get_telemetry_context(metadata: TelemetryMetadataContainer) -> Context:
|
||||
"""
|
||||
Retrieves the telemetry context from the given metadata.
|
||||
|
||||
Args:
|
||||
metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry context.
|
||||
|
||||
Returns:
|
||||
Context: The telemetry context extracted from the metadata, or an empty context if the metadata is None.
|
||||
"""
|
||||
if metadata is None:
|
||||
return Context()
|
||||
elif isinstance(metadata, EnvelopeMetadata):
|
||||
return extract(_get_carrier_for_envelope_metadata(metadata))
|
||||
elif hasattr(metadata, "__getitem__"):
|
||||
return extract(_get_carrier_for_remote_call_metadata(metadata))
|
||||
else:
|
||||
raise ValueError(f"Unknown metadata type: {type(metadata)}")
|
||||
|
||||
|
||||
def get_telemetry_links(
|
||||
metadata: TelemetryMetadataContainer,
|
||||
) -> Optional[Sequence[Link]]:
|
||||
"""
|
||||
Retrieves the telemetry links from the given metadata.
|
||||
|
||||
Args:
|
||||
metadata (Optional[EnvelopeMetadata]): The metadata containing the telemetry links.
|
||||
|
||||
Returns:
|
||||
Optional[Sequence[Link]]: The telemetry links extracted from the metadata, or None if there are no links.
|
||||
"""
|
||||
if metadata is None:
|
||||
return None
|
||||
elif isinstance(metadata, EnvelopeMetadata):
|
||||
context = extract(_get_carrier_for_envelope_metadata(metadata))
|
||||
elif hasattr(metadata, "__getitem__"):
|
||||
context = extract(_get_carrier_for_remote_call_metadata(metadata))
|
||||
else:
|
||||
return None
|
||||
# Retrieve the extracted SpanContext from the context.
|
||||
linked_span = get_current_span(context)
|
||||
# Use the linked span to get the SpanContext.
|
||||
span_context = linked_span.get_span_context()
|
||||
# Create a Link object using the SpanContext.
|
||||
return [Link(span_context)]
|
||||
99
agent_dhal/agentdhal_core/_telemetry/_tracing.py
Normal file
99
agent_dhal/agentdhal_core/_telemetry/_tracing.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Dict, Generic, Iterator, Optional
|
||||
|
||||
from opentelemetry.trace import NoOpTracerProvider, Span, SpanKind, TracerProvider, get_tracer_provider
|
||||
from opentelemetry.util import types
|
||||
|
||||
from ._propagation import TelemetryMetadataContainer, get_telemetry_links
|
||||
from ._tracing_config import Destination, ExtraAttributes, Operation, TracingConfig
|
||||
|
||||
|
||||
class TraceHelper(Generic[Operation, Destination, ExtraAttributes]):
|
||||
"""
|
||||
TraceHelper is a utility class to assist with tracing operations using OpenTelemetry.
|
||||
|
||||
This class provides a context manager `trace_block` to create and manage spans for tracing operations,
|
||||
following semantic conventions and supporting nested spans through metadata contexts.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tracer_provider: TracerProvider | None,
|
||||
instrumentation_builder_config: TracingConfig[Operation, Destination, ExtraAttributes],
|
||||
) -> None:
|
||||
self.instrumentation_builder_config = instrumentation_builder_config
|
||||
|
||||
disable_runtime_tracing = os.environ.get("AUTOGEN_DISABLE_RUNTIME_TRACING") == "true"
|
||||
if disable_runtime_tracing:
|
||||
self.tracer_provider: TracerProvider = NoOpTracerProvider()
|
||||
self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}")
|
||||
return
|
||||
|
||||
# Evaluate in order: first try tracer_provider param, then get_tracer_provider(), finally fallback to NoOp
|
||||
# This allows for nested tracing with a default tracer provided by the user
|
||||
self.tracer_provider = tracer_provider or get_tracer_provider() or NoOpTracerProvider()
|
||||
self.tracer = self.tracer_provider.get_tracer(f"agentdhal {instrumentation_builder_config.name}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def trace_block(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
parent: Optional[TelemetryMetadataContainer],
|
||||
*,
|
||||
extraAttributes: ExtraAttributes | None = None,
|
||||
kind: Optional[SpanKind] = None,
|
||||
attributes: Optional[types.Attributes] = None,
|
||||
start_time: Optional[int] = None,
|
||||
record_exception: bool = True,
|
||||
set_status_on_exception: bool = True,
|
||||
end_on_exit: bool = True,
|
||||
) -> Iterator[Span]:
|
||||
"""
|
||||
Thin wrapper on top of start_as_current_span.
|
||||
1. It helps us follow semantic conventions
|
||||
2. It helps us get contexts from metadata so we can get nested spans
|
||||
|
||||
Args:
|
||||
operation (MessagingOperation): The messaging operation being performed.
|
||||
destination (MessagingDestination): The messaging destination being used.
|
||||
parent Optional[TelemetryMetadataContainer]: The parent telemetry metadta context
|
||||
kind (SpanKind, optional): The kind of span. If not provided, it maps to PRODUCER or CONSUMER depending on the operation.
|
||||
extraAttributes (ExtraAttributes, optional): Additional defined attributes for the span. Defaults to None.
|
||||
attributes (Optional[types.Attributes], optional): Additional non-defined attributes for the span. Defaults to None.
|
||||
start_time (Optional[int], optional): The start time of the span. Defaults to None.
|
||||
record_exception (bool, optional): Whether to record exceptions. Defaults to True.
|
||||
set_status_on_exception (bool, optional): Whether to set the status on exception. Defaults to True.
|
||||
end_on_exit (bool, optional): Whether to end the span on exit. Defaults to True.
|
||||
|
||||
Yields:
|
||||
Iterator[Span]: The span object.
|
||||
|
||||
"""
|
||||
span_name = self.instrumentation_builder_config.get_span_name(operation, destination)
|
||||
span_kind = kind or self.instrumentation_builder_config.get_span_kind(operation)
|
||||
# context = get_telemetry_context(parent) if parent else None
|
||||
context = None # TODO: we may need to remove other code for using custom context.
|
||||
links = get_telemetry_links(parent) if parent else None
|
||||
attributes_with_defaults: Dict[str, types.AttributeValue] = {}
|
||||
for key, value in (attributes or {}).items():
|
||||
attributes_with_defaults[key] = value
|
||||
instrumentation_attributes = self.instrumentation_builder_config.build_attributes(
|
||||
operation, destination, extraAttributes
|
||||
)
|
||||
for key, value in instrumentation_attributes.items():
|
||||
attributes_with_defaults[key] = value
|
||||
with self.tracer.start_as_current_span(
|
||||
span_name,
|
||||
context,
|
||||
span_kind,
|
||||
attributes_with_defaults,
|
||||
links,
|
||||
start_time,
|
||||
record_exception,
|
||||
set_status_on_exception,
|
||||
end_on_exit,
|
||||
) as span:
|
||||
yield span
|
||||
201
agent_dhal/agentdhal_core/_telemetry/_tracing_config.py
Normal file
201
agent_dhal/agentdhal_core/_telemetry/_tracing_config.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generic, List, Literal, TypedDict, TypeVar, Union
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util import types
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from .._agent_id import AgentId
|
||||
from .._topic import TopicId
|
||||
from ._constants import NAMESPACE
|
||||
|
||||
logger = logging.getLogger("agentdhal_core")
|
||||
event_logger = logging.getLogger("agentdhal_core.events")
|
||||
|
||||
Operation = TypeVar("Operation", bound=str)
|
||||
Destination = TypeVar("Destination")
|
||||
ExtraAttributes = TypeVar("ExtraAttributes")
|
||||
|
||||
|
||||
class TracingConfig(ABC, Generic[Operation, Destination, ExtraAttributes]):
|
||||
"""
|
||||
A protocol that defines the configuration for instrumentation.
|
||||
|
||||
This protocol specifies the required properties and methods that any
|
||||
instrumentation configuration class must implement. It includes a
|
||||
property to get the name of the module being instrumented and a method
|
||||
to build attributes for the instrumentation configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Returns:
|
||||
The name of the module that is being instrumented.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def build_attributes(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
extraAttributes: ExtraAttributes | None,
|
||||
) -> Dict[str, types.AttributeValue]:
|
||||
"""
|
||||
Builds the attributes for the instrumentation configuration.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: The attributes for the instrumentation configuration.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_span_name(
|
||||
self,
|
||||
operation: Operation,
|
||||
destination: Destination,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the span name based on the given operation and destination.
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
destination (Optional[MessagingDestination]): The messaging destination.
|
||||
|
||||
Returns:
|
||||
str: The span name.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_span_kind(
|
||||
self,
|
||||
operation: Operation,
|
||||
) -> SpanKind:
|
||||
"""
|
||||
Determines the span kind based on the given messaging operation.
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
|
||||
Returns:
|
||||
SpanKind: The span kind based on the messaging operation.
|
||||
"""
|
||||
|
||||
|
||||
class ExtraMessageRuntimeAttributes(TypedDict):
|
||||
message_size: NotRequired[int]
|
||||
message_type: NotRequired[str]
|
||||
|
||||
|
||||
MessagingDestination = Union[AgentId, TopicId, str, None]
|
||||
MessagingOperation = Literal["create", "send", "publish", "receive", "intercept", "process", "ack"]
|
||||
|
||||
|
||||
class MessageRuntimeTracingConfig(
|
||||
TracingConfig[MessagingOperation, MessagingDestination, ExtraMessageRuntimeAttributes]
|
||||
):
|
||||
"""
|
||||
A class that defines the configuration for message runtime instrumentation.
|
||||
|
||||
This class implements the TracingConfig protocol and provides
|
||||
the name of the module being instrumented and the attributes for the
|
||||
instrumentation configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, runtime_name: str) -> None:
|
||||
self._runtime_name = runtime_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._runtime_name
|
||||
|
||||
def build_attributes(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
destination: MessagingDestination,
|
||||
extraAttributes: ExtraMessageRuntimeAttributes | None,
|
||||
) -> Dict[str, types.AttributeValue]:
|
||||
attrs: Dict[str, types.AttributeValue] = {
|
||||
"messaging.operation": self._get_operation_type(operation),
|
||||
"messaging.destination": self._get_destination_str(destination),
|
||||
}
|
||||
if extraAttributes:
|
||||
# TODO: Make this more pythonic?
|
||||
if "message_size" in extraAttributes:
|
||||
attrs["messaging.message.envelope.size"] = extraAttributes["message_size"]
|
||||
if "message_type" in extraAttributes:
|
||||
attrs["messaging.message.type"] = extraAttributes["message_type"]
|
||||
return attrs
|
||||
|
||||
def get_span_name(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
destination: MessagingDestination,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the span name based on the given operation and destination.
|
||||
Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-name
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
destination (Optional[MessagingDestination]): The messaging destination.
|
||||
|
||||
Returns:
|
||||
str: The span name.
|
||||
"""
|
||||
span_parts: List[str] = [operation]
|
||||
destination_str = self._get_destination_str(destination)
|
||||
if destination_str:
|
||||
span_parts.append(destination_str)
|
||||
span_name = " ".join(span_parts)
|
||||
return f"{NAMESPACE} {span_name}"
|
||||
|
||||
def get_span_kind(
|
||||
self,
|
||||
operation: MessagingOperation,
|
||||
) -> SpanKind:
|
||||
"""
|
||||
Determines the span kind based on the given messaging operation.
|
||||
Semantic Conventions - https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/#span-kind
|
||||
|
||||
Parameters:
|
||||
operation (MessagingOperation): The messaging operation.
|
||||
|
||||
Returns:
|
||||
SpanKind: The span kind based on the messaging operation.
|
||||
"""
|
||||
if operation in ["create", "send", "publish"]:
|
||||
return SpanKind.PRODUCER
|
||||
elif operation in ["receive", "intercept", "process", "ack"]:
|
||||
return SpanKind.CONSUMER
|
||||
else:
|
||||
return SpanKind.CLIENT
|
||||
|
||||
# TODO: Use stringified convention
|
||||
def _get_destination_str(self, destination: MessagingDestination) -> str:
|
||||
if isinstance(destination, AgentId):
|
||||
return f"{destination.type}.({destination.key})-A"
|
||||
elif isinstance(destination, TopicId):
|
||||
return f"{destination.type}.({destination.source})-T"
|
||||
elif isinstance(destination, str):
|
||||
return destination
|
||||
elif destination is None:
|
||||
return ""
|
||||
else:
|
||||
raise ValueError(f"Unknown destination type: {type(destination)}")
|
||||
|
||||
def _get_operation_type(self, operation: MessagingOperation) -> str:
|
||||
if operation in ["send", "publish"]:
|
||||
return "publish"
|
||||
if operation in ["create"]:
|
||||
return "create"
|
||||
elif operation in ["receive", "intercept", "ack"]:
|
||||
return "receive"
|
||||
elif operation in ["process"]:
|
||||
return "process"
|
||||
else:
|
||||
return "Unknown"
|
||||
47
agent_dhal/agentdhal_core/_topic.py
Normal file
47
agent_dhal/agentdhal_core/_topic.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
def is_valid_topic_type(value: str) -> bool:
|
||||
return bool(re.match(r"^[\w\-\.\:\=]+\Z", value))
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class TopicId:
|
||||
"""
|
||||
TopicId defines the scope of a broadcast message. In essence, agent runtime implements a publish-subscribe model through its broadcast API: when publishing a message, the topic must be specified.
|
||||
|
||||
See here for more information: :ref:`topic_and_subscription_topic`
|
||||
"""
|
||||
|
||||
type: str
|
||||
"""Type of the event that this topic_id contains. Adhere's to the cloud event spec.
|
||||
|
||||
Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z
|
||||
|
||||
Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#type
|
||||
"""
|
||||
|
||||
source: str
|
||||
"""Identifies the context in which an event happened. Adhere's to the cloud event spec.
|
||||
|
||||
Learn more here: https://github.com/cloudevents/spec/blob/main/cloudevents/spec.md#source-1
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if is_valid_topic_type(self.type) is False:
|
||||
raise ValueError(f"Invalid topic type: {self.type}. Must match the pattern: ^[\\w\\-\\.\\:\\=]+\\Z")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}/{self.source}"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, topic_id: str) -> Self:
|
||||
"""Convert a string of the format ``type/source`` into a TopicId"""
|
||||
items = topic_id.split("/", maxsplit=1)
|
||||
if len(items) != 2:
|
||||
raise ValueError(f"Invalid topic id: {topic_id}")
|
||||
type, source = items[0], items[1]
|
||||
return cls(type, source)
|
||||
33
agent_dhal/agentdhal_core/_type_helpers.py
Normal file
33
agent_dhal/agentdhal_core/_type_helpers.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections.abc import Sequence
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Optional, Type, Union, get_args, get_origin
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Union or origin is UnionType
|
||||
|
||||
|
||||
def is_optional(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Optional
|
||||
|
||||
|
||||
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
|
||||
class AnyType:
|
||||
pass
|
||||
|
||||
|
||||
def get_types(t: object) -> Sequence[Type[Any]] | None:
|
||||
if is_union(t):
|
||||
return get_args(t)
|
||||
elif is_optional(t):
|
||||
return tuple(list(get_args(t)) + [NoneType])
|
||||
elif t is Any:
|
||||
return (AnyType,)
|
||||
elif isinstance(t, type):
|
||||
return (t,)
|
||||
elif isinstance(t, NoneType):
|
||||
return (NoneType,)
|
||||
else:
|
||||
return None
|
||||
69
agent_dhal/agentdhal_core/_type_prefix_subscription.py
Normal file
69
agent_dhal/agentdhal_core/_type_prefix_subscription.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import uuid
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class TypePrefixSubscription(Subscription):
|
||||
"""This subscription matches on topics based on a prefix of the type and maps to agents using the source of the topic as the agent key.
|
||||
|
||||
This subscription causes each source to have its own agent instance.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import TypePrefixSubscription
|
||||
|
||||
subscription = TypePrefixSubscription(topic_type_prefix="t1", agent_type="a1")
|
||||
|
||||
In this case:
|
||||
|
||||
- A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1`
|
||||
- A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
- A topic_id with type `t1SUFFIX` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
|
||||
Args:
|
||||
topic_type_prefix (str): Topic type prefix to match against
|
||||
agent_type (str): Agent type to handle this subscription
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type_prefix: str, agent_type: str | AgentType, id: str | None = None):
|
||||
self._topic_type_prefix = topic_type_prefix
|
||||
if isinstance(agent_type, AgentType):
|
||||
self._agent_type = agent_type.type
|
||||
else:
|
||||
self._agent_type = agent_type
|
||||
self._id = id or str(uuid.uuid4())
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def topic_type_prefix(self) -> str:
|
||||
return self._topic_type_prefix
|
||||
|
||||
@property
|
||||
def agent_type(self) -> str:
|
||||
return self._agent_type
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
return topic_id.type.startswith(self._topic_type_prefix)
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
if not self.is_match(topic_id):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TypePrefixSubscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id or (
|
||||
self.agent_type == other.agent_type and self.topic_type_prefix == other.topic_type_prefix
|
||||
)
|
||||
66
agent_dhal/agentdhal_core/_type_subscription.py
Normal file
66
agent_dhal/agentdhal_core/_type_subscription.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import uuid
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
from .exceptions import CantHandleException
|
||||
|
||||
|
||||
class TypeSubscription(Subscription):
|
||||
"""This subscription matches on topics based on the type and maps to agents using the source of the topic as the agent key.
|
||||
|
||||
This subscription causes each source to have its own agent instance.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import TypeSubscription
|
||||
|
||||
subscription = TypeSubscription(topic_type="t1", agent_type="a1")
|
||||
|
||||
In this case:
|
||||
|
||||
- A topic_id with type `t1` and source `s1` will be handled by an agent of type `a1` with key `s1`
|
||||
- A topic_id with type `t1` and source `s2` will be handled by an agent of type `a1` with key `s2`.
|
||||
|
||||
Args:
|
||||
topic_type (str): Topic type to match against
|
||||
agent_type (str): Agent type to handle this subscription
|
||||
"""
|
||||
|
||||
def __init__(self, topic_type: str, agent_type: str | AgentType, id: str | None = None):
|
||||
self._topic_type = topic_type
|
||||
if isinstance(agent_type, AgentType):
|
||||
self._agent_type = agent_type.type
|
||||
else:
|
||||
self._agent_type = agent_type
|
||||
self._id = id or str(uuid.uuid4())
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def topic_type(self) -> str:
|
||||
return self._topic_type
|
||||
|
||||
@property
|
||||
def agent_type(self) -> str:
|
||||
return self._agent_type
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
return topic_id.type == self._topic_type
|
||||
|
||||
def map_to_agent(self, topic_id: TopicId) -> AgentId:
|
||||
if not self.is_match(topic_id):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TypeSubscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type)
|
||||
12
agent_dhal/agentdhal_core/_types.py
Normal file
12
agent_dhal/agentdhal_core/_types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCall:
|
||||
id: str
|
||||
# JSON args
|
||||
arguments: str
|
||||
# Function to call
|
||||
name: str
|
||||
21
agent_dhal/agentdhal_core/code_executor/__init__.py
Normal file
21
agent_dhal/agentdhal_core/code_executor/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from ._base import CodeBlock, CodeExecutor, CodeResult
|
||||
from ._func_with_reqs import (
|
||||
Alias,
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
Import,
|
||||
ImportFromModule,
|
||||
with_requirements,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CodeBlock",
|
||||
"CodeExecutor",
|
||||
"CodeResult",
|
||||
"Alias",
|
||||
"ImportFromModule",
|
||||
"Import",
|
||||
"FunctionWithRequirements",
|
||||
"FunctionWithRequirementsStr",
|
||||
"with_requirements",
|
||||
]
|
||||
102
agent_dhal/agentdhal_core/code_executor/_base.py
Normal file
102
agent_dhal/agentdhal_core/code_executor/_base.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from types import TracebackType
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeBlock:
|
||||
"""A code block extracted fromm an agent message."""
|
||||
|
||||
code: str
|
||||
language: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeResult:
|
||||
"""Result of a code execution."""
|
||||
|
||||
exit_code: int
|
||||
output: str
|
||||
|
||||
|
||||
class CodeExecutor(ABC, ComponentBase[BaseModel]):
|
||||
"""Executes code blocks and returns the result.
|
||||
|
||||
This is an abstract base class for code executors. It defines the interface
|
||||
for executing code blocks and returning the result. A concrete implementation
|
||||
of this class should be provided to execute code blocks in a specific
|
||||
environment. For example, :class:`~agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor` executes
|
||||
code blocks in a command line environment in a Docker container.
|
||||
|
||||
It is recommended for subclass to be used as a context manager to ensure
|
||||
that resources are cleaned up properly. To do this, implement the
|
||||
:meth:`~agentdhal_core.code_executor.CodeExecutor.start` and
|
||||
:meth:`~agentdhal_core.code_executor.CodeExecutor.stop` methods
|
||||
that will be called when entering and exiting the context manager.
|
||||
|
||||
"""
|
||||
|
||||
component_type = "code_executor"
|
||||
|
||||
@abstractmethod
|
||||
async def execute_code_blocks(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CodeResult:
|
||||
"""Execute code blocks and return the result.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
Args:
|
||||
code_blocks (List[CodeBlock]): The code blocks to execute.
|
||||
|
||||
Returns:
|
||||
CodeResult: The result of the code execution.
|
||||
|
||||
Raises:
|
||||
ValueError: Errors in user inputs
|
||||
asyncio.TimeoutError: Code execution timeouts
|
||||
asyncio.CancelledError: CancellationToken evoked during execution
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Start the code executor."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the code executor and release any resources."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def restart(self) -> None:
|
||||
"""Restart the code executor.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
This method is called when the agent is reset.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
|
||||
) -> Optional[bool]:
|
||||
await self.stop()
|
||||
return None
|
||||
277
agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py
Normal file
277
agent_dhal/agentdhal_core/code_executor/_func_with_reqs.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from importlib.abc import SourceLoader
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
from textwrap import dedent, indent
|
||||
from typing import Any, Callable, Generic, List, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return func.func
|
||||
|
||||
if isinstance(func, FunctionWithRequirements):
|
||||
code = inspect.getsource(func.func)
|
||||
else:
|
||||
code = inspect.getsource(func)
|
||||
# Strip the decorator
|
||||
if code.startswith("@"):
|
||||
code = code[code.index("\n") + 1 :]
|
||||
return code
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Alias:
|
||||
name: str
|
||||
alias: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImportFromModule:
|
||||
module: str
|
||||
imports: Tuple[Union[str, Alias], ...]
|
||||
|
||||
# backward compatibility
|
||||
def __init__(
|
||||
self,
|
||||
module: str,
|
||||
imports: Union[Tuple[Union[str, Alias], ...], List[Union[str, Alias]]],
|
||||
):
|
||||
object.__setattr__(self, "module", module)
|
||||
if isinstance(imports, list):
|
||||
object.__setattr__(self, "imports", tuple(imports))
|
||||
else:
|
||||
object.__setattr__(self, "imports", imports)
|
||||
|
||||
|
||||
Import = Union[str, ImportFromModule, Alias]
|
||||
|
||||
|
||||
def _import_to_str(im: Import) -> str:
|
||||
if isinstance(im, str):
|
||||
return f"import {im}"
|
||||
elif isinstance(im, Alias):
|
||||
return f"import {im.name} as {im.alias}"
|
||||
else:
|
||||
|
||||
def to_str(i: Union[str, Alias]) -> str:
|
||||
if isinstance(i, str):
|
||||
return i
|
||||
else:
|
||||
return f"{i.name} as {i.alias}"
|
||||
|
||||
imports = ", ".join(map(to_str, im.imports))
|
||||
return f"from {im.module} import {imports}"
|
||||
|
||||
|
||||
class _StringLoader(SourceLoader):
|
||||
def __init__(self, data: str):
|
||||
self.data = data
|
||||
|
||||
def get_source(self, fullname: str) -> str:
|
||||
return self.data
|
||||
|
||||
def get_data(self, path: str) -> bytes:
|
||||
return self.data.encode("utf-8")
|
||||
|
||||
def get_filename(self, fullname: str) -> str:
|
||||
return "<not a real path>/" + fullname + ".py"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirementsStr:
|
||||
func: str
|
||||
compiled_func: Callable[..., Any]
|
||||
_func_name: str
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []):
|
||||
self.func = func
|
||||
self.python_packages = python_packages
|
||||
self.global_imports = global_imports
|
||||
|
||||
module_name = "func_module"
|
||||
loader = _StringLoader(func)
|
||||
spec = spec_from_loader(module_name, loader)
|
||||
if spec is None:
|
||||
raise ValueError("Could not create spec")
|
||||
module = module_from_spec(spec)
|
||||
if spec.loader is None:
|
||||
raise ValueError("Could not create loader")
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not compile function: {e}") from e
|
||||
|
||||
functions = inspect.getmembers(module, inspect.isfunction)
|
||||
if len(functions) != 1:
|
||||
raise ValueError("The string must contain exactly one function")
|
||||
|
||||
self._func_name, self.compiled_func = functions[0]
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("String based function with requirement objects are not directly callable")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirements(Generic[T, P]):
|
||||
func: Callable[P, T]
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_callable(
|
||||
cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirements[T, P]:
|
||||
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
|
||||
|
||||
@staticmethod
|
||||
def from_str(
|
||||
func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirementsStr:
|
||||
return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
|
||||
|
||||
# Type this based on F
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_requirements(
|
||||
python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
|
||||
"""
|
||||
Decorate a function with package and import requirements for code execution environments.
|
||||
|
||||
This decorator makes a function available for reference in dynamically executed code blocks
|
||||
by wrapping it in a `FunctionWithRequirements` object that tracks its dependencies. When the
|
||||
decorated function is passed to a code executor, it can be imported by name in the executed
|
||||
code, with all dependencies automatically handled.
|
||||
|
||||
Args:
|
||||
python_packages (Sequence[str], optional): Python packages required by the function.
|
||||
Can include version specifications (e.g., ["pandas>=1.0.0"]). Defaults to [].
|
||||
global_imports (Sequence[Import], optional): Import statements required by the function.
|
||||
Can be strings ("numpy"), ImportFromModule objects, or Alias objects. Defaults to [].
|
||||
|
||||
Returns:
|
||||
Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: A decorator that wraps
|
||||
the target function, preserving its functionality while registering its dependencies.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import tempfile
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import with_requirements, CodeBlock
|
||||
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
|
||||
import pandas
|
||||
|
||||
@with_requirements(python_packages=["pandas"], global_imports=["pandas"])
|
||||
def load_data() -> pandas.DataFrame:
|
||||
\"\"\"Load some sample data.
|
||||
|
||||
Returns:
|
||||
pandas.DataFrame: A DataFrame with sample data
|
||||
\"\"\"
|
||||
data = {
|
||||
"name": ["John", "Anna", "Peter", "Linda"],
|
||||
"location": ["New York", "Paris", "Berlin", "London"],
|
||||
"age": [24, 13, 53, 33],
|
||||
}
|
||||
return pandas.DataFrame(data)
|
||||
|
||||
async def run_example():
|
||||
# The decorated function can be used in executed code
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data])
|
||||
code = f\"\"\"from {executor.functions_module} import load_data
|
||||
|
||||
# Use the imported function
|
||||
data = load_data()
|
||||
print(data['name'][0])\"\"\"
|
||||
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[CodeBlock(language="python", code=code)],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
print(result.output) # Output: John
|
||||
|
||||
# Run the async example
|
||||
asyncio.run(run_example())
|
||||
"""
|
||||
|
||||
def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
|
||||
func_with_reqs = FunctionWithRequirements(
|
||||
python_packages=python_packages, global_imports=global_imports, func=func
|
||||
)
|
||||
|
||||
functools.update_wrapper(func_with_reqs, func)
|
||||
return func_with_reqs
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_python_functions_file(
|
||||
funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
|
||||
) -> str:
|
||||
""":meta private:"""
|
||||
# First collect all global imports
|
||||
global_imports: Set[Import] = set()
|
||||
for func in funcs:
|
||||
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
|
||||
global_imports.update(func.global_imports)
|
||||
|
||||
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
|
||||
|
||||
for func in funcs:
|
||||
content += _to_code(func) + "\n\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
|
||||
"""Generate a stub for a function as a string
|
||||
|
||||
Args:
|
||||
func (Callable[..., Any]): The function to generate a stub for
|
||||
|
||||
Returns:
|
||||
str: The stub for the function
|
||||
"""
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return to_stub(func.compiled_func)
|
||||
|
||||
content = f"def {func.__name__}{inspect.signature(func)}:\n"
|
||||
docstring = func.__doc__
|
||||
|
||||
if docstring:
|
||||
docstring = dedent(docstring)
|
||||
docstring = '"""' + docstring + '"""'
|
||||
docstring = indent(docstring, " ")
|
||||
content += docstring + "\n"
|
||||
|
||||
content += " ..."
|
||||
return content
|
||||
|
||||
|
||||
def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
|
||||
return _to_code(func)
|
||||
|
||||
|
||||
def import_to_str(im: Import) -> str:
|
||||
return _import_to_str(im)
|
||||
17
agent_dhal/agentdhal_core/exceptions.py
Normal file
17
agent_dhal/agentdhal_core/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
__all__ = ["CantHandleException", "UndeliverableException", "MessageDroppedException", "NotAccessibleError"]
|
||||
|
||||
|
||||
class CantHandleException(Exception):
|
||||
"""Raised when a handler can't handle the exception."""
|
||||
|
||||
|
||||
class UndeliverableException(Exception):
|
||||
"""Raised when a message can't be delivered."""
|
||||
|
||||
|
||||
class MessageDroppedException(Exception):
|
||||
"""Raised when a message is dropped."""
|
||||
|
||||
|
||||
class NotAccessibleError(Exception):
|
||||
"""Tried to access a value that is not accessible. For example if it is remote cannot be accessed locally."""
|
||||
294
agent_dhal/agentdhal_core/logging.py
Normal file
294
agent_dhal/agentdhal_core/logging.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class LLMCallEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, Any]],
|
||||
response: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""To be used by model clients to log the call to the LLM.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable.
|
||||
response (Dict[str, Any]): The response of the call. Must be json serializable.
|
||||
prompt_tokens (int): Number of tokens used in the prompt.
|
||||
completion_tokens (int): Number of tokens used in the completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMCallEvent
|
||||
|
||||
response = {"content": "Hello, world!"}
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20, response=response, messages=messages))
|
||||
|
||||
"""
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMCall"
|
||||
self.kwargs["messages"] = messages
|
||||
self.kwargs["response"] = response
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class LLMStreamStartEvent:
|
||||
"""To be used by model clients to log the start of a stream.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): The messages used in the call. Must be json serializable.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMStreamStartEvent
|
||||
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMStreamStartEvent(messages=messages))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMStreamStart"
|
||||
self.kwargs["messages"] = messages
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class LLMStreamEndEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
response: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""To be used by model clients to log the end of a stream.
|
||||
|
||||
Args:
|
||||
response (Dict[str, Any]): The response of the call. Must be json serializable.
|
||||
prompt_tokens (int): Number of tokens used in the prompt.
|
||||
completion_tokens (int): Number of tokens used in the completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import logging
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import LLMStreamEndEvent
|
||||
|
||||
response = {"content": "Hello, world!"}
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMStreamEndEvent(prompt_tokens=10, completion_tokens=20, response=response))
|
||||
|
||||
"""
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["type"] = "LLMStreamEnd"
|
||||
self.kwargs["response"] = response
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class ToolCallEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
result: str,
|
||||
) -> None:
|
||||
"""Used by subclasses of :class:`~agentdhal_core.tools.BaseTool` to log executions of tools.
|
||||
|
||||
Args:
|
||||
tool_name (str): The name of the tool.
|
||||
arguments (Dict[str, Any]): The arguments of the tool. Must be json serializable.
|
||||
result (str): The result of the tool. Must be a string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core import EVENT_LOGGER_NAME
|
||||
from agentdhal_core.logging import ToolCallEvent
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(ToolCallEvent(tool_name="Tool1", call_id="123", arguments={"arg1": "value1"}))
|
||||
|
||||
"""
|
||||
self.kwargs: Dict[str, Any] = {}
|
||||
self.kwargs["type"] = "ToolCall"
|
||||
self.kwargs["tool_name"] = tool_name
|
||||
self.kwargs["arguments"] = arguments
|
||||
self.kwargs["result"] = result
|
||||
try:
|
||||
agent_id = MessageHandlerContext.agent_id()
|
||||
except RuntimeError:
|
||||
agent_id = None
|
||||
self.kwargs["agent_id"] = None if agent_id is None else str(agent_id)
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageKind(Enum):
|
||||
DIRECT = 1
|
||||
PUBLISH = 2
|
||||
RESPOND = 3
|
||||
|
||||
|
||||
class DeliveryStage(Enum):
|
||||
SEND = 1
|
||||
DELIVER = 2
|
||||
|
||||
|
||||
class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
delivery_stage: DeliveryStage,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["delivery_stage"] = str(delivery_stage)
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageDroppedEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["type"] = "MessageDropped"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageHandlerExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
handling_agent: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["handling_agent"] = str(handling_agent)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "MessageHandlerException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class AgentConstructionExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent_id: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["agent_id"] = str(agent_id)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "AgentConstructionException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
11
agent_dhal/agentdhal_core/memory/__init__.py
Normal file
11
agent_dhal/agentdhal_core/memory/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ._base_memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from ._list_memory import ListMemory
|
||||
|
||||
__all__ = [
|
||||
"Memory",
|
||||
"MemoryContent",
|
||||
"MemoryQueryResult",
|
||||
"UpdateContextResult",
|
||||
"MemoryMimeType",
|
||||
"ListMemory",
|
||||
]
|
||||
132
agent_dhal/agentdhal_core/memory/_base_memory.py
Normal file
132
agent_dhal/agentdhal_core/memory/_base_memory.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._image import Image
|
||||
from ..model_context import ChatCompletionContext
|
||||
|
||||
|
||||
class MemoryMimeType(Enum):
|
||||
"""Supported MIME types for memory content."""
|
||||
|
||||
TEXT = "text/plain"
|
||||
JSON = "application/json"
|
||||
MARKDOWN = "text/markdown"
|
||||
IMAGE = "image/*"
|
||||
BINARY = "application/octet-stream"
|
||||
|
||||
|
||||
ContentType = Union[str, bytes, Dict[str, Any], Image]
|
||||
|
||||
|
||||
class MemoryContent(BaseModel):
|
||||
"""A memory content item."""
|
||||
|
||||
content: ContentType
|
||||
"""The content of the memory item. It can be a string, bytes, dict, or :class:`~agentdhal_core.Image`."""
|
||||
|
||||
mime_type: MemoryMimeType | str
|
||||
"""The MIME type of the memory content."""
|
||||
|
||||
metadata: Dict[str, Any] | None = None
|
||||
"""Metadata associated with the memory item."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("mime_type")
|
||||
def serialize_mime_type(self, mime_type: MemoryMimeType | str) -> str:
|
||||
"""Serialize the MIME type to a string."""
|
||||
if isinstance(mime_type, MemoryMimeType):
|
||||
return mime_type.value
|
||||
return mime_type
|
||||
|
||||
|
||||
class MemoryQueryResult(BaseModel):
|
||||
"""Result of a memory :meth:`~agentdhal_core.memory.Memory.query` operation."""
|
||||
|
||||
results: List[MemoryContent]
|
||||
|
||||
|
||||
class UpdateContextResult(BaseModel):
|
||||
"""Result of a memory :meth:`~agentdhal_core.memory.Memory.update_context` operation."""
|
||||
|
||||
memories: MemoryQueryResult
|
||||
|
||||
|
||||
class Memory(ABC, ComponentBase[BaseModel]):
|
||||
"""Protocol defining the interface for memory implementations.
|
||||
|
||||
A memory is the storage for data that can be used to enrich or modify the model context.
|
||||
|
||||
A memory implementation can use any storage mechanism, such as a list, a database, or a file system.
|
||||
It can also use any retrieval mechanism, such as vector search or text search.
|
||||
It is up to the implementation to decide how to store and retrieve data.
|
||||
|
||||
It is also a memory implementation's responsibility to update the model context
|
||||
with relevant memory content based on the current model context and querying the memory store.
|
||||
|
||||
See :class:`~agentdhal_core.memory.ListMemory` for an example implementation.
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
|
||||
@abstractmethod
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
) -> UpdateContextResult:
|
||||
"""
|
||||
Update the provided model context using relevant memory content.
|
||||
|
||||
Args:
|
||||
model_context: The context to update.
|
||||
|
||||
Returns:
|
||||
UpdateContextResult containing relevant memories
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
"""
|
||||
Query the memory store and return relevant entries.
|
||||
|
||||
Args:
|
||||
query: Query content item
|
||||
cancellation_token: Optional token to cancel operation
|
||||
**kwargs: Additional implementation-specific parameters
|
||||
|
||||
Returns:
|
||||
MemoryQueryResult containing memory entries with relevance scores
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
"""
|
||||
Add a new content to memory.
|
||||
|
||||
Args:
|
||||
content: The memory content to add
|
||||
cancellation_token: Optional token to cancel operation
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> None:
|
||||
"""Clear all entries from memory."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Clean up any resources used by the memory implementation."""
|
||||
...
|
||||
172
agent_dhal/agentdhal_core/memory/_list_memory.py
Normal file
172
agent_dhal/agentdhal_core/memory/_list_memory.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component
|
||||
from ..model_context import ChatCompletionContext
|
||||
from ..models import SystemMessage
|
||||
from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
|
||||
|
||||
|
||||
class ListMemoryConfig(BaseModel):
|
||||
"""Configuration for ListMemory component."""
|
||||
|
||||
name: str | None = None
|
||||
"""Optional identifier for this memory instance."""
|
||||
memory_contents: List[MemoryContent] = Field(default_factory=list)
|
||||
"""List of memory contents stored in this memory instance."""
|
||||
|
||||
|
||||
class ListMemory(Memory, Component[ListMemoryConfig]):
|
||||
"""Simple chronological list-based memory implementation.
|
||||
|
||||
This memory implementation stores contents in a list and retrieves them in
|
||||
chronological order. It has an `update_context` method that updates model contexts
|
||||
by appending all stored memories.
|
||||
|
||||
The memory content can be directly accessed and modified through the content property,
|
||||
allowing external applications to manage memory contents directly.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core.memory import ListMemory, MemoryContent
|
||||
from agentdhal_core.model_context import BufferedChatCompletionContext
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Initialize memory
|
||||
memory = ListMemory(name="chat_history")
|
||||
|
||||
# Add memory content
|
||||
content = MemoryContent(content="User prefers formal language", mime_type="text/plain")
|
||||
await memory.add(content)
|
||||
|
||||
# Directly modify memory contents
|
||||
memory.content = [MemoryContent(content="New preference", mime_type="text/plain")]
|
||||
|
||||
# Create a model context
|
||||
model_context = BufferedChatCompletionContext(buffer_size=10)
|
||||
|
||||
# Update a model context with memory
|
||||
await memory.update_context(model_context)
|
||||
|
||||
# See the updated model context
|
||||
print(await model_context.get_messages())
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Args:
|
||||
name: Optional identifier for this memory instance
|
||||
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
component_provider_override = "agentdhal_core.memory.ListMemory"
|
||||
component_config_schema = ListMemoryConfig
|
||||
|
||||
def __init__(self, name: str | None = None, memory_contents: List[MemoryContent] | None = None) -> None:
|
||||
self._name = name or "default_list_memory"
|
||||
self._contents: List[MemoryContent] = memory_contents if memory_contents is not None else []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the memory instance identifier.
|
||||
|
||||
Returns:
|
||||
str: Memory instance name
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def content(self) -> List[MemoryContent]:
|
||||
"""Get the current memory contents.
|
||||
|
||||
Returns:
|
||||
List[MemoryContent]: List of stored memory contents
|
||||
"""
|
||||
return self._contents
|
||||
|
||||
@content.setter
|
||||
def content(self, value: List[MemoryContent]) -> None:
|
||||
"""Set the memory contents.
|
||||
|
||||
Args:
|
||||
value: New list of memory contents to store
|
||||
"""
|
||||
self._contents = value
|
||||
|
||||
async def update_context(
|
||||
self,
|
||||
model_context: ChatCompletionContext,
|
||||
) -> UpdateContextResult:
|
||||
"""Update the model context by appending memory content.
|
||||
|
||||
This method mutates the provided model_context by adding all memories as a
|
||||
SystemMessage.
|
||||
|
||||
Args:
|
||||
model_context: The context to update. Will be mutated if memories exist.
|
||||
|
||||
Returns:
|
||||
UpdateContextResult containing the memories that were added to the context
|
||||
"""
|
||||
|
||||
if not self._contents:
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
|
||||
|
||||
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)]
|
||||
|
||||
if memory_strings:
|
||||
memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n"
|
||||
await model_context.add_message(SystemMessage(content=memory_context))
|
||||
|
||||
return UpdateContextResult(memories=MemoryQueryResult(results=self._contents))
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str | MemoryContent = "",
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MemoryQueryResult:
|
||||
"""Return all memories without any filtering.
|
||||
|
||||
Args:
|
||||
query: Ignored in this implementation
|
||||
cancellation_token: Optional token to cancel operation
|
||||
**kwargs: Additional parameters (ignored)
|
||||
|
||||
Returns:
|
||||
MemoryQueryResult containing all stored memories
|
||||
"""
|
||||
_ = query, cancellation_token, kwargs
|
||||
return MemoryQueryResult(results=self._contents)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
"""Add new content to memory.
|
||||
|
||||
Args:
|
||||
content: Memory content to store
|
||||
cancellation_token: Optional token to cancel operation
|
||||
"""
|
||||
self._contents.append(content)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all memory content."""
|
||||
self._contents = []
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cleanup resources if needed."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ListMemoryConfig) -> Self:
|
||||
return cls(name=config.name, memory_contents=config.memory_contents)
|
||||
|
||||
def _to_config(self) -> ListMemoryConfig:
|
||||
return ListMemoryConfig(name=self.name, memory_contents=self._contents)
|
||||
16
agent_dhal/agentdhal_core/model_context/__init__.py
Normal file
16
agent_dhal/agentdhal_core/model_context/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from ._buffered_chat_completion_context import BufferedChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
|
||||
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
|
||||
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
|
||||
from ._unbounded_chat_completion_context import (
|
||||
UnboundedChatCompletionContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatCompletionContext",
|
||||
"ChatCompletionContextState",
|
||||
"UnboundedChatCompletionContext",
|
||||
"BufferedChatCompletionContext",
|
||||
"TokenLimitedChatCompletionContext",
|
||||
"HeadAndTailChatCompletionContext",
|
||||
]
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from ..models import FunctionExecutionResultMessage, LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class BufferedChatCompletionContextConfig(BaseModel):
|
||||
buffer_size: int
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class BufferedChatCompletionContext(ChatCompletionContext, Component[BufferedChatCompletionContextConfig]):
|
||||
"""A buffered chat completion context that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
Args:
|
||||
buffer_size (int): The size of the buffer.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
component_config_schema = BufferedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.BufferedChatCompletionContext"
|
||||
|
||||
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if buffer_size <= 0:
|
||||
raise ValueError("buffer_size must be greater than 0.")
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
messages = self._messages[-self._buffer_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
def _to_config(self) -> BufferedChatCompletionContextConfig:
|
||||
return BufferedChatCompletionContextConfig(
|
||||
buffer_size=self._buffer_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: BufferedChatCompletionContextConfig) -> Self:
|
||||
return cls(**config.model_dump())
|
||||
@@ -0,0 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .._component_config import ComponentBase
|
||||
from ..models import LLMMessage
|
||||
|
||||
|
||||
class ChatCompletionContext(ABC, ComponentBase[BaseModel]):
|
||||
"""An abstract base class for defining the interface of a chat completion context.
|
||||
A chat completion context lets agents store and retrieve LLM messages.
|
||||
It can be implemented with different recall strategies.
|
||||
|
||||
Args:
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
|
||||
Example:
|
||||
|
||||
To create a custom model context that filters out the thought field from AssistantMessage.
|
||||
This is useful for reasoning models like DeepSeek R1, which produces
|
||||
very long thought that is not needed for subsequent completions.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from typing import List
|
||||
|
||||
from agentdhal_core.model_context import UnboundedChatCompletionContext
|
||||
from agentdhal_core.models import AssistantMessage, LLMMessage
|
||||
|
||||
|
||||
class ReasoningModelContext(UnboundedChatCompletionContext):
|
||||
\"\"\"A model context for reasoning models.\"\"\"
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
messages = await super().get_messages()
|
||||
# Filter out thought field from AssistantMessage.
|
||||
messages_out: List[LLMMessage] = []
|
||||
for message in messages:
|
||||
if isinstance(message, AssistantMessage):
|
||||
message.thought = None
|
||||
messages_out.append(message)
|
||||
return messages_out
|
||||
|
||||
"""
|
||||
|
||||
component_type = "chat_completion_context"
|
||||
|
||||
def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = []
|
||||
if initial_messages is not None:
|
||||
self._messages.extend(initial_messages)
|
||||
self._initial_messages = initial_messages
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the context."""
|
||||
self._messages.append(message)
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages(self) -> List[LLMMessage]: ...
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the context."""
|
||||
self._messages = []
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return ChatCompletionContextState(messages=self._messages).model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = ChatCompletionContextState.model_validate(state).messages
|
||||
|
||||
|
||||
class ChatCompletionContextState(BaseModel):
|
||||
messages: List[LLMMessage] = Field(default_factory=list)
|
||||
@@ -0,0 +1,76 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from .._types import FunctionCall
|
||||
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class HeadAndTailChatCompletionContextConfig(BaseModel):
|
||||
head_size: int
|
||||
tail_size: int
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndTailChatCompletionContextConfig]):
|
||||
"""A chat completion context that keeps a view of the first n and last m messages,
|
||||
where n is the head size and m is the tail size. The head and tail sizes
|
||||
are set at initialization.
|
||||
|
||||
Args:
|
||||
head_size (int): The size of the head.
|
||||
tail_size (int): The size of the tail.
|
||||
initial_messages (List[LLMMessage] | None): The initial messages.
|
||||
"""
|
||||
|
||||
component_config_schema = HeadAndTailChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.HeadAndTailChatCompletionContext"
|
||||
|
||||
def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if head_size <= 0:
|
||||
raise ValueError("head_size must be greater than 0.")
|
||||
if tail_size <= 0:
|
||||
raise ValueError("tail_size must be greater than 0.")
|
||||
self._head_size = head_size
|
||||
self._tail_size = tail_size
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
||||
head_messages = self._messages[: self._head_size]
|
||||
# Handle the last message is a function call message.
|
||||
if (
|
||||
head_messages
|
||||
and isinstance(head_messages[-1], AssistantMessage)
|
||||
and isinstance(head_messages[-1].content, list)
|
||||
and all(isinstance(item, FunctionCall) for item in head_messages[-1].content)
|
||||
):
|
||||
# Remove the last message from the head.
|
||||
head_messages = head_messages[:-1]
|
||||
|
||||
tail_messages = self._messages[-self._tail_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the tail.
|
||||
tail_messages = tail_messages[1:]
|
||||
|
||||
num_skipped = len(self._messages) - self._head_size - self._tail_size
|
||||
if num_skipped <= 0:
|
||||
# If there are not enough messages to fill the head and tail,
|
||||
# return all messages.
|
||||
return self._messages
|
||||
|
||||
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
return head_messages + placeholder_messages + tail_messages
|
||||
|
||||
def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
|
||||
return HeadAndTailChatCompletionContextConfig(
|
||||
head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HeadAndTailChatCompletionContextConfig) -> Self:
|
||||
return cls(head_size=config.head_size, tail_size=config.tail_size, initial_messages=config.initial_messages)
|
||||
@@ -0,0 +1,94 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ..models import ChatCompletionClient, FunctionExecutionResultMessage, LLMMessage
|
||||
from ..tools import ToolSchema
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContextConfig(BaseModel):
|
||||
model_client: ComponentModel
|
||||
token_limit: int | None = None
|
||||
tool_schema: List[ToolSchema] | None = None
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
|
||||
"""(Experimental) A token based chat completion context maintains a view of the context up to a token limit.
|
||||
|
||||
.. note::
|
||||
|
||||
Added in v0.4.10. This is an experimental component and may change in the future.
|
||||
|
||||
Args:
|
||||
model_client (ChatCompletionClient): The model client to use for token counting.
|
||||
The model client must implement the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens`
|
||||
and :meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` methods.
|
||||
token_limit (int | None): The maximum number of tokens to keep in the context
|
||||
using the :meth:`~agentdhal_core.models.ChatCompletionClient.count_tokens` method.
|
||||
If None, the context will be limited by the model client using the
|
||||
:meth:`~agentdhal_core.models.ChatCompletionClient.remaining_tokens` method.
|
||||
tools (List[ToolSchema] | None): A list of tool schema to use in the context.
|
||||
initial_messages (List[LLMMessage] | None): A list of initial messages to include in the context.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = TokenLimitedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.TokenLimitedChatCompletionContext"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
token_limit: int | None = None,
|
||||
tool_schema: List[ToolSchema] | None = None,
|
||||
initial_messages: List[LLMMessage] | None = None,
|
||||
) -> None:
|
||||
super().__init__(initial_messages)
|
||||
if token_limit is not None and token_limit <= 0:
|
||||
raise ValueError("token_limit must be greater than 0.")
|
||||
self._token_limit = token_limit
|
||||
self._model_client = model_client
|
||||
self._tool_schema = tool_schema or []
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `token_limit` tokens in recent messages. If the token limit is not
|
||||
provided, then return as many messages as the remaining token allowed by the model client."""
|
||||
messages = list(self._messages)
|
||||
if self._token_limit is None:
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
while remaining_tokens < 0 and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
remaining_tokens = self._model_client.remaining_tokens(messages, tools=self._tool_schema)
|
||||
else:
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
while token_count > self._token_limit and len(messages) > 0:
|
||||
middle_index = len(messages) // 2
|
||||
messages.pop(middle_index)
|
||||
token_count = self._model_client.count_tokens(messages, tools=self._tool_schema)
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Handle the first message is a function call result message.
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
|
||||
return TokenLimitedChatCompletionContextConfig(
|
||||
model_client=self._model_client.dump_component(),
|
||||
token_limit=self._token_limit,
|
||||
tool_schema=self._tool_schema,
|
||||
initial_messages=self._initial_messages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
|
||||
return cls(
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
token_limit=config.token_limit,
|
||||
tool_schema=config.tool_schema,
|
||||
initial_messages=config.initial_messages,
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._component_config import Component
|
||||
from ..models import LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class UnboundedChatCompletionContextConfig(BaseModel):
|
||||
initial_messages: List[LLMMessage] | None = None
|
||||
|
||||
|
||||
class UnboundedChatCompletionContext(ChatCompletionContext, Component[UnboundedChatCompletionContextConfig]):
|
||||
"""An unbounded chat completion context that keeps a view of the all the messages."""
|
||||
|
||||
component_config_schema = UnboundedChatCompletionContextConfig
|
||||
component_provider_override = "agentdhal_core.model_context.UnboundedChatCompletionContext"
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
return self._messages
|
||||
|
||||
def _to_config(self) -> UnboundedChatCompletionContextConfig:
|
||||
return UnboundedChatCompletionContextConfig(initial_messages=self._initial_messages)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: UnboundedChatCompletionContextConfig) -> Self:
|
||||
return cls(initial_messages=config.initial_messages)
|
||||
0
agent_dhal/agentdhal_core/py.typed
Normal file
0
agent_dhal/agentdhal_core/py.typed
Normal file
17
agent_dhal/agentdhal_core/tool_agent/__init__.py
Normal file
17
agent_dhal/agentdhal_core/tool_agent/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ._caller_loop import tool_agent_caller_loop
|
||||
from ._tool_agent import (
|
||||
InvalidToolArgumentsException,
|
||||
ToolAgent,
|
||||
ToolException,
|
||||
ToolExecutionException,
|
||||
ToolNotFoundException,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ToolAgent",
|
||||
"ToolException",
|
||||
"ToolNotFoundException",
|
||||
"InvalidToolArgumentsException",
|
||||
"ToolExecutionException",
|
||||
"tool_agent_caller_loop",
|
||||
]
|
||||
80
agent_dhal/agentdhal_core/tool_agent/_caller_loop.py
Normal file
80
agent_dhal/agentdhal_core/tool_agent/_caller_loop.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall
|
||||
from ..models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
)
|
||||
from ..tools import Tool, ToolSchema
|
||||
from ._tool_agent import ToolException
|
||||
|
||||
|
||||
async def tool_agent_caller_loop(
|
||||
caller: BaseAgent | AgentRuntime,
|
||||
tool_agent_id: AgentId,
|
||||
model_client: ChatCompletionClient,
|
||||
input_messages: List[LLMMessage],
|
||||
tool_schema: List[ToolSchema] | List[Tool],
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
caller_source: str = "assistant",
|
||||
) -> List[LLMMessage]:
|
||||
"""Start a caller loop for a tool agent. This function sends messages to the tool agent
|
||||
and the model client in an alternating fashion until the model client stops generating tool calls.
|
||||
|
||||
Args:
|
||||
tool_agent_id (AgentId): The Agent ID of the tool agent.
|
||||
input_messages (List[LLMMessage]): The list of input messages.
|
||||
model_client (ChatCompletionClient): The model client to use for the model API.
|
||||
tool_schema (List[Tool | ToolSchema]): The list of tools that the model can use.
|
||||
|
||||
Returns:
|
||||
List[LLMMessage]: The list of output messages created in the caller loop.
|
||||
"""
|
||||
|
||||
generated_messages: List[LLMMessage] = []
|
||||
|
||||
# Get a response from the model.
|
||||
response = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token)
|
||||
# Add the response to the generated messages.
|
||||
generated_messages.append(AssistantMessage(content=response.content, source=caller_source))
|
||||
|
||||
# Keep iterating until the model stops generating tool calls.
|
||||
while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
|
||||
# Execute functions called by the model by sending messages to tool agent.
|
||||
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
|
||||
*[
|
||||
caller.send_message(
|
||||
message=call,
|
||||
recipient=tool_agent_id,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
for call in response.content
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
# Combine the results into a single response and handle exceptions.
|
||||
function_results: List[FunctionExecutionResult] = []
|
||||
for result in results:
|
||||
if isinstance(result, FunctionExecutionResult):
|
||||
function_results.append(result)
|
||||
elif isinstance(result, ToolException):
|
||||
function_results.append(
|
||||
FunctionExecutionResult(
|
||||
content=f"Error: {result}", call_id=result.call_id, is_error=True, name=result.name
|
||||
)
|
||||
)
|
||||
elif isinstance(result, BaseException):
|
||||
raise result # Unexpected exception.
|
||||
generated_messages.append(FunctionExecutionResultMessage(content=function_results))
|
||||
# Query the model again with the new response.
|
||||
response = await model_client.create(
|
||||
input_messages + generated_messages, tools=tool_schema, cancellation_token=cancellation_token
|
||||
)
|
||||
generated_messages.append(AssistantMessage(content=response.content, source=caller_source))
|
||||
|
||||
# Return the generated messages.
|
||||
return generated_messages
|
||||
96
agent_dhal/agentdhal_core/tool_agent/_tool_agent.py
Normal file
96
agent_dhal/agentdhal_core/tool_agent/_tool_agent.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from .. import FunctionCall, MessageContext, RoutedAgent, message_handler
|
||||
from ..models import FunctionExecutionResult
|
||||
from ..tools import Tool
|
||||
|
||||
__all__ = [
|
||||
"ToolAgent",
|
||||
"ToolException",
|
||||
"ToolNotFoundException",
|
||||
"InvalidToolArgumentsException",
|
||||
"ToolExecutionException",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolException(BaseException):
|
||||
call_id: str
|
||||
content: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolNotFoundException(ToolException):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidToolArgumentsException(ToolException):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutionException(ToolException):
|
||||
pass
|
||||
|
||||
|
||||
class ToolAgent(RoutedAgent):
|
||||
"""A tool agent accepts direct messages of the type `FunctionCall`,
|
||||
executes the requested tool with the provided arguments, and returns the
|
||||
result as `FunctionExecutionResult` messages.
|
||||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
tools (List[Tool]): The list of tools that the agent can execute.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
tools: List[Tool],
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._tools = tools
|
||||
|
||||
@property
|
||||
def tools(self) -> List[Tool]:
|
||||
return self._tools
|
||||
|
||||
@message_handler
|
||||
async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) -> FunctionExecutionResult:
|
||||
"""Handles a `FunctionCall` message by executing the requested tool with the provided arguments.
|
||||
|
||||
Args:
|
||||
message (FunctionCall): The function call message.
|
||||
cancellation_token (CancellationToken): The cancellation token.
|
||||
|
||||
Returns:
|
||||
FunctionExecutionResult: The result of the function execution.
|
||||
|
||||
Raises:
|
||||
ToolNotFoundException: If the tool is not found.
|
||||
InvalidToolArgumentsException: If the tool arguments are invalid.
|
||||
ToolExecutionException: If the tool execution fails.
|
||||
"""
|
||||
tool = next((tool for tool in self._tools if tool.name == message.name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundException(
|
||||
call_id=message.id, content=f"Error: Tool not found: {message.name}", name=message.name
|
||||
)
|
||||
else:
|
||||
try:
|
||||
arguments = json.loads(message.arguments)
|
||||
result = await tool.run_json(
|
||||
args=arguments, cancellation_token=ctx.cancellation_token, call_id=message.id
|
||||
)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvalidToolArgumentsException(
|
||||
call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}", name=message.name
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ToolExecutionException(call_id=message.id, content=f"Error: {e}", name=message.name) from e
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False, name=message.name)
|
||||
31
agent_dhal/agentdhal_core/tools/__init__.py
Normal file
31
agent_dhal/agentdhal_core/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from ._base import (
|
||||
BaseStreamTool,
|
||||
BaseTool,
|
||||
BaseToolWithState,
|
||||
ParametersSchema,
|
||||
StreamTool,
|
||||
Tool,
|
||||
ToolOverride,
|
||||
ToolSchema,
|
||||
)
|
||||
from ._function_tool import FunctionTool
|
||||
from ._static_workbench import StaticStreamWorkbench, StaticWorkbench
|
||||
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"StreamTool",
|
||||
"ToolSchema",
|
||||
"ParametersSchema",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"BaseStreamTool",
|
||||
"FunctionTool",
|
||||
"Workbench",
|
||||
"ToolResult",
|
||||
"TextResultContent",
|
||||
"ImageResultContent",
|
||||
"StaticWorkbench",
|
||||
"StaticStreamWorkbench",
|
||||
"ToolOverride",
|
||||
]
|
||||
294
agent_dhal/agentdhal_core/tools/_base.py
Normal file
294
agent_dhal/agentdhal_core/tools/_base.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import jsonref
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .. import EVENT_LOGGER_NAME, CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._function_utils import normalize_annotated_type
|
||||
from .._telemetry import trace_tool_span
|
||||
from ..logging import ToolCallEvent
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, contravariant=True)
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ParametersSchema(TypedDict):
|
||||
type: str
|
||||
properties: Dict[str, Any]
|
||||
required: NotRequired[Sequence[str]]
|
||||
additionalProperties: NotRequired[bool]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
|
||||
parameters: NotRequired[ParametersSchema]
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
strict: NotRequired[bool]
|
||||
|
||||
|
||||
class ToolOverride(BaseModel):
|
||||
"""Override configuration for a tool's name and/or description."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema: ...
|
||||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
def return_type(self) -> Type[Any]: ...
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None: ...
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str: ...
|
||||
|
||||
async def run_json(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> Any: ...
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]: ...
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StreamTool(Tool, Protocol):
|
||||
def run_json_stream(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> AsyncGenerator[Any, None]: ...
|
||||
|
||||
|
||||
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
|
||||
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
|
||||
StateT = TypeVar("StateT", bound=BaseModel)
|
||||
StreamT = TypeVar("StreamT", bound=BaseModel, covariant=True)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
|
||||
component_type = "tool"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
name: str,
|
||||
description: str,
|
||||
strict: bool = False,
|
||||
) -> None:
|
||||
self._args_type = args_type
|
||||
# Normalize Annotated to the base type.
|
||||
self._return_type = normalize_annotated_type(return_type)
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._strict = strict
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
model_schema: Dict[str, Any] = self._args_type.model_json_schema()
|
||||
|
||||
if "$defs" in model_schema:
|
||||
model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False)) # type: ignore
|
||||
del model_schema["$defs"]
|
||||
|
||||
parameters = ParametersSchema(
|
||||
type="object",
|
||||
properties=model_schema["properties"],
|
||||
required=model_schema.get("required", []),
|
||||
additionalProperties=model_schema.get("additionalProperties", False),
|
||||
)
|
||||
|
||||
# If strict is enabled, the tool schema should list all properties as required.
|
||||
assert "required" in parameters
|
||||
if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()):
|
||||
raise ValueError(
|
||||
"Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode."
|
||||
)
|
||||
|
||||
assert "additionalProperties" in parameters
|
||||
if self._strict and parameters["additionalProperties"]:
|
||||
raise ValueError(
|
||||
"Strict mode is enabled but additional argument is also enabled. This is not allowed in strict mode."
|
||||
)
|
||||
|
||||
tool_schema = ToolSchema(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
parameters=parameters,
|
||||
strict=self._strict,
|
||||
)
|
||||
return tool_schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
def args_type(self) -> Type[BaseModel]:
|
||||
return self._args_type
|
||||
|
||||
def return_type(self) -> Type[Any]:
|
||||
return self._return_type
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str:
|
||||
if isinstance(value, BaseModel):
|
||||
dumped = value.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return json.dumps(dumped)
|
||||
return str(dumped)
|
||||
|
||||
return str(value)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
|
||||
|
||||
async def run_json(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> Any:
|
||||
"""Run the tool with the provided arguments in a dictionary.
|
||||
|
||||
Args:
|
||||
args (Mapping[str, Any]): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): A token to cancel the operation if needed.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
|
||||
Returns:
|
||||
Any: The return value of the tool's run method.
|
||||
"""
|
||||
with trace_tool_span(
|
||||
tool_name=self._name,
|
||||
tool_description=self._description,
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
# Execute the tool's run method
|
||||
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
|
||||
|
||||
# Log the tool call event
|
||||
event = ToolCallEvent(
|
||||
tool_name=self.name,
|
||||
arguments=dict(args), # Using the raw args passed to run_json
|
||||
result=self.return_value_as_string(return_value),
|
||||
)
|
||||
logger.info(event)
|
||||
|
||||
return return_value
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BaseStreamTool(
|
||||
BaseTool[ArgsT, ReturnT], StreamTool, ABC, Generic[ArgsT, StreamT, ReturnT], ComponentBase[BaseModel]
|
||||
):
|
||||
component_type = "tool"
|
||||
|
||||
@abstractmethod
|
||||
def run_stream(self, args: ArgsT, cancellation_token: CancellationToken) -> AsyncGenerator[StreamT | ReturnT, None]:
|
||||
"""Run the tool with the provided arguments and return a stream of data and end with the final return value."""
|
||||
...
|
||||
|
||||
async def run_json_stream(
|
||||
self,
|
||||
args: Mapping[str, Any],
|
||||
cancellation_token: CancellationToken,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[StreamT | ReturnT, None]:
|
||||
"""Run the tool with the provided arguments in a dictionary and return a stream of data
|
||||
from the tool's :meth:`run_stream` method and end with the final return value.
|
||||
|
||||
Args:
|
||||
args (Mapping[str, Any]): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): A token to cancel the operation if needed.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[StreamT | ReturnT, None]: A generator yielding results from the tool's :meth:`run_stream` method.
|
||||
"""
|
||||
return_value: ReturnT | StreamT | None = None
|
||||
with trace_tool_span(
|
||||
tool_name=self._name,
|
||||
tool_description=self._description,
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
# Execute the tool's run_stream method
|
||||
async for result in self.run_stream(self._args_type.model_validate(args), cancellation_token):
|
||||
return_value = result
|
||||
yield result
|
||||
|
||||
assert return_value is not None, "The tool must yield a final return value at the end of the stream."
|
||||
if not isinstance(return_value, self._return_type):
|
||||
raise TypeError(
|
||||
f"Expected return value of type {self._return_type.__name__}, but got {type(return_value).__name__}"
|
||||
)
|
||||
|
||||
# Log the tool call event
|
||||
event = ToolCallEvent(
|
||||
tool_name=self.name,
|
||||
arguments=dict(args), # Using the raw args passed to run_json
|
||||
result=self.return_value_as_string(return_value),
|
||||
)
|
||||
logger.info(event)
|
||||
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
state_type: Type[StateT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
self._state_type = state_type
|
||||
|
||||
component_type = "tool"
|
||||
|
||||
@abstractmethod
|
||||
def save_state(self) -> StateT: ...
|
||||
|
||||
@abstractmethod
|
||||
def load_state(self, state: StateT) -> None: ...
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]:
|
||||
return self.save_state().model_dump()
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
self.load_state(self._state_type.model_validate(state))
|
||||
181
agent_dhal/agentdhal_core/tools/_function_tool.py
Normal file
181
agent_dhal/agentdhal_core/tools/_function_tool.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import warnings
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import CancellationToken
|
||||
from .._component_config import Component
|
||||
from .._function_utils import (
|
||||
args_base_model_from_signature,
|
||||
get_typed_signature,
|
||||
)
|
||||
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class FunctionToolConfig(BaseModel):
|
||||
"""Configuration for a function tool."""
|
||||
|
||||
source_code: str
|
||||
name: str
|
||||
description: str
|
||||
global_imports: Sequence[Import]
|
||||
has_cancellation_support: bool
|
||||
|
||||
|
||||
class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
|
||||
"""
|
||||
Create custom tools by wrapping standard Python functions.
|
||||
|
||||
`FunctionTool` offers an interface for executing Python functions either asynchronously or synchronously.
|
||||
Each function must include type annotations for all parameters and its return type. These annotations
|
||||
enable `FunctionTool` to generate a schema necessary for input validation, serialization, and for informing
|
||||
the LLM about expected parameters. When the LLM prepares a function call, it leverages this schema to
|
||||
generate arguments that align with the function's specifications.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the user's responsibility to verify that the tool's output type matches the expected type.
|
||||
|
||||
Args:
|
||||
func (Callable[..., ReturnT | Awaitable[ReturnT]]): The function to wrap and expose as a tool.
|
||||
description (str): A description to inform the model of the function's purpose, specifying what
|
||||
it does and the context in which it should be called.
|
||||
name (str, optional): An optional custom name for the tool. Defaults to
|
||||
the function's original name if not provided.
|
||||
strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly
|
||||
defined in the function signature, and no default values will be allowed. Defaults to False.
|
||||
This is required to be set to True when used with models in structured output mode.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import random
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import FunctionTool
|
||||
from typing_extensions import Annotated
|
||||
import asyncio
|
||||
|
||||
|
||||
async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float:
|
||||
# Simulates a stock price retrieval by returning a random float within a specified range.
|
||||
return random.uniform(10, 200)
|
||||
|
||||
|
||||
async def example():
|
||||
# Initialize a FunctionTool instance for retrieving stock prices.
|
||||
stock_price_tool = FunctionTool(get_stock_price, description="Fetch the stock price for a given ticker.")
|
||||
|
||||
# Execute the tool with cancellation support.
|
||||
cancellation_token = CancellationToken()
|
||||
result = await stock_price_tool.run_json({"ticker": "AAPL", "date": "2021/01/01"}, cancellation_token)
|
||||
|
||||
# Output the result as a formatted string.
|
||||
print(stock_price_tool.return_value_as_string(result))
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.FunctionTool"
|
||||
component_config_schema = FunctionToolConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
description: str,
|
||||
name: str | None = None,
|
||||
global_imports: Sequence[Import] = [],
|
||||
strict: bool = False,
|
||||
) -> None:
|
||||
self._func = func
|
||||
self._global_imports = global_imports
|
||||
self._signature = get_typed_signature(func)
|
||||
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
|
||||
args_model = args_base_model_from_signature(func_name + "args", self._signature)
|
||||
self._has_cancellation_support = "cancellation_token" in self._signature.parameters
|
||||
return_type = self._signature.return_annotation
|
||||
super().__init__(args_model, return_type, func_name, description, strict)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
kwargs = {}
|
||||
|
||||
for name in self._signature.parameters.keys():
|
||||
if hasattr(args, name):
|
||||
kwargs[name] = getattr(args, name)
|
||||
|
||||
if asyncio.iscoroutinefunction(self._func):
|
||||
if self._has_cancellation_support:
|
||||
result = await self._func(**kwargs, cancellation_token=cancellation_token)
|
||||
else:
|
||||
result = await self._func(**kwargs)
|
||||
else:
|
||||
if self._has_cancellation_support:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self._func,
|
||||
**kwargs,
|
||||
cancellation_token=cancellation_token,
|
||||
),
|
||||
)
|
||||
else:
|
||||
future = asyncio.get_event_loop().run_in_executor(None, functools.partial(self._func, **kwargs))
|
||||
cancellation_token.link_future(future)
|
||||
result = await future
|
||||
|
||||
return result
|
||||
|
||||
def _to_config(self) -> FunctionToolConfig:
|
||||
return FunctionToolConfig(
|
||||
source_code=dedent(to_code(self._func)),
|
||||
global_imports=self._global_imports,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
has_cancellation_support=self._has_cancellation_support,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: FunctionToolConfig) -> Self:
|
||||
warnings.warn(
|
||||
"\n⚠️ SECURITY WARNING ⚠️\n"
|
||||
"Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
|
||||
"Only load configs from TRUSTED sources to prevent arbitrary code execution.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
exec_globals: dict[str, Any] = {}
|
||||
|
||||
# Execute imports first
|
||||
for import_stmt in config.global_imports:
|
||||
import_code = import_to_str(import_stmt)
|
||||
try:
|
||||
exec(import_code, exec_globals)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
|
||||
) from e
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e
|
||||
|
||||
# Execute function code
|
||||
try:
|
||||
exec(config.source_code, exec_globals)
|
||||
func_name = config.source_code.split("def ")[1].split("(")[0]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not compile and load function: {e}") from e
|
||||
|
||||
# Get function and verify it's callable
|
||||
func: Callable[..., Any] = exec_globals[func_name]
|
||||
if not callable(func):
|
||||
raise TypeError(f"Expected function but got {type(func)}")
|
||||
|
||||
return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)
|
||||
225
agent_dhal/agentdhal_core/tools/_static_workbench.py
Normal file
225
agent_dhal/agentdhal_core/tools/_static_workbench.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
|
||||
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench
|
||||
|
||||
|
||||
class StaticWorkbenchConfig(BaseModel):
|
||||
tools: List[ComponentModel] = []
|
||||
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StateicWorkbenchState(BaseModel):
|
||||
type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
|
||||
tools: Dict[str, Mapping[str, Any]] = {}
|
||||
|
||||
|
||||
class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
"""
|
||||
A workbench that provides a static set of tools that do not change after
|
||||
each tool execution.
|
||||
|
||||
Args:
|
||||
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
|
||||
The tools should be subclasses of :class:`~agentdhal_core.tools.BaseTool`.
|
||||
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
|
||||
names to override configurations for name and/or description. This allows
|
||||
customizing how tools appear to consumers while maintaining the underlying
|
||||
tool functionality.
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.StaticWorkbench"
|
||||
component_config_schema = StaticWorkbenchConfig
|
||||
|
||||
def __init__(
|
||||
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
|
||||
) -> None:
|
||||
self._tools = tools
|
||||
self._tool_overrides = tool_overrides or {}
|
||||
|
||||
# Build reverse mapping from override names to original names for call_tool
|
||||
self._override_name_to_original: Dict[str, str] = {}
|
||||
existing_tool_names = {tool.name for tool in self._tools}
|
||||
|
||||
for original_name, override in self._tool_overrides.items():
|
||||
if override.name and override.name != original_name:
|
||||
# Check for conflicts with existing tool names
|
||||
if override.name in existing_tool_names and override.name != original_name:
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' conflicts with existing tool name. "
|
||||
f"Override names must not conflict with any tool names."
|
||||
)
|
||||
# Check for conflicts with other override names
|
||||
if override.name in self._override_name_to_original:
|
||||
existing_original = self._override_name_to_original[override.name]
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' is used by multiple tools: "
|
||||
f"'{existing_original}' and '{original_name}'. Override names must be unique."
|
||||
)
|
||||
self._override_name_to_original[override.name] = original_name
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
result_schemas: List[ToolSchema] = []
|
||||
for tool in self._tools:
|
||||
original_schema = tool.schema
|
||||
|
||||
# Apply overrides if they exist for this tool
|
||||
if tool.name in self._tool_overrides:
|
||||
override = self._tool_overrides[tool.name]
|
||||
# Create a new ToolSchema with overrides applied
|
||||
schema: ToolSchema = {
|
||||
"name": override.name if override.name is not None else original_schema["name"],
|
||||
"description": override.description
|
||||
if override.description is not None
|
||||
else original_schema.get("description", ""),
|
||||
}
|
||||
# Copy optional fields
|
||||
if "parameters" in original_schema:
|
||||
schema["parameters"] = original_schema["parameters"]
|
||||
if "strict" in original_schema:
|
||||
schema["strict"] = original_schema["strict"]
|
||||
else:
|
||||
schema = original_schema
|
||||
|
||||
result_schemas.append(schema)
|
||||
return result_schemas
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
# Check if the name is an override name and map it back to the original
|
||||
original_name = self._override_name_to_original.get(name, name)
|
||||
|
||||
tool = next((tool for tool in self._tools if tool.name == original_name), None)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
name=name, # Return the requested name (which might be overridden)
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
|
||||
cancellation_token.link_future(result_future)
|
||||
actual_tool_output = await result_future
|
||||
is_error = False
|
||||
result_str = tool.return_value_as_string(actual_tool_output)
|
||||
except Exception as e:
|
||||
result_str = self._format_errors(e)
|
||||
is_error = True
|
||||
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
return None
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
tool_states = StateicWorkbenchState()
|
||||
for tool in self._tools:
|
||||
tool_states.tools[tool.name] = await tool.save_state_json()
|
||||
return tool_states.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
parsed_state = StateicWorkbenchState.model_validate(state)
|
||||
for tool in self._tools:
|
||||
if tool.name in parsed_state.tools:
|
||||
await tool.load_state_json(parsed_state.tools[tool.name])
|
||||
|
||||
def _to_config(self) -> StaticWorkbenchConfig:
|
||||
return StaticWorkbenchConfig(
|
||||
tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
|
||||
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
|
||||
error_message = ""
|
||||
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
|
||||
# ExceptionGroup is available in Python 3.11+.
|
||||
# TODO: how to make this compatible with Python 3.10?
|
||||
for sub_exception in error.exceptions: # type: ignore
|
||||
error_message += self._format_errors(sub_exception) # type: ignore
|
||||
else:
|
||||
error_message += f"{str(error)}\n"
|
||||
return error_message.strip()
|
||||
|
||||
|
||||
class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench):
|
||||
"""
|
||||
A workbench that provides a static set of tools that do not change after
|
||||
each tool execution, and supports streaming results.
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.StaticStreamWorkbench"
|
||||
|
||||
async def call_tool_stream(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[Any | ToolResult, None]:
|
||||
tool = next((tool for tool in self._tools if tool.name == name), None)
|
||||
if tool is None:
|
||||
yield ToolResult(
|
||||
name=name,
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
return
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
actual_tool_output: Any | None = None
|
||||
if isinstance(tool, StreamTool):
|
||||
previous_result: Any | None = None
|
||||
try:
|
||||
async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id):
|
||||
if previous_result is not None:
|
||||
yield previous_result
|
||||
previous_result = result
|
||||
actual_tool_output = previous_result
|
||||
except Exception as e:
|
||||
# If there was a previous result before the exception, yield it first
|
||||
if previous_result is not None:
|
||||
yield previous_result
|
||||
# Then yield the error result
|
||||
result_str = self._format_errors(e)
|
||||
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True)
|
||||
return
|
||||
else:
|
||||
# If the tool is not a stream tool, we run it normally and yield the result
|
||||
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
|
||||
cancellation_token.link_future(result_future)
|
||||
actual_tool_output = await result_future
|
||||
is_error = False
|
||||
result_str = tool.return_value_as_string(actual_tool_output)
|
||||
except Exception as e:
|
||||
result_str = self._format_errors(e)
|
||||
is_error = True
|
||||
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
216
agent_dhal/agentdhal_core/tools/_workbench.py
Normal file
216
agent_dhal/agentdhal_core/tools/_workbench.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._image import Image
|
||||
from ._base import ToolSchema
|
||||
|
||||
|
||||
class TextResultContent(BaseModel):
|
||||
"""
|
||||
Text result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["TextResultContent"] = "TextResultContent"
|
||||
|
||||
content: str
|
||||
"""The text content of the result."""
|
||||
|
||||
|
||||
class ImageResultContent(BaseModel):
|
||||
"""
|
||||
Image result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["ImageResultContent"] = "ImageResultContent"
|
||||
|
||||
content: Image
|
||||
"""The image content of the result."""
|
||||
|
||||
|
||||
ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")]
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""
|
||||
A result of a tool execution by a workbench.
|
||||
"""
|
||||
|
||||
type: Literal["ToolResult"] = "ToolResult"
|
||||
|
||||
name: str
|
||||
"""The name of the tool that was executed."""
|
||||
|
||||
result: List[ResultContent]
|
||||
"""The result of the tool execution."""
|
||||
|
||||
is_error: bool = False
|
||||
"""Whether the tool execution resulted in an error."""
|
||||
|
||||
def to_text(self, replace_image: str | None = None) -> str:
|
||||
"""
|
||||
Convert the result to a text string.
|
||||
|
||||
Args:
|
||||
replace_image (str | None): The string to replace the image content with.
|
||||
If None, the image content will be included in the text as base64 string.
|
||||
|
||||
Returns:
|
||||
str: The text representation of the result.
|
||||
"""
|
||||
parts: List[str] = []
|
||||
for content in self.result:
|
||||
if isinstance(content, TextResultContent):
|
||||
parts.append(content.content)
|
||||
elif isinstance(content, ImageResultContent):
|
||||
if replace_image is not None:
|
||||
parts.append(replace_image)
|
||||
else:
|
||||
parts.append(f"[Image: {content.content.to_base64()}]")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class Workbench(ABC, ComponentBase[BaseModel]):
|
||||
"""
|
||||
A workbench is a component that provides a set of tools that may share
|
||||
resources and state.
|
||||
|
||||
A workbench is responsible for managing the lifecycle of the tools and
|
||||
providing a single interface to call them. The tools provided by the workbench
|
||||
may be dynamic and their availabilities may change after each tool execution.
|
||||
|
||||
A workbench can be started by calling the :meth:`~agentdhal_core.tools.Workbench.start` method
|
||||
and stopped by calling the :meth:`~agentdhal_core.tools.Workbench.stop` method.
|
||||
It can also be used as an asynchronous context manager, which will automatically
|
||||
start and stop the workbench when entering and exiting the context.
|
||||
"""
|
||||
|
||||
component_type = "workbench"
|
||||
|
||||
@abstractmethod
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
"""
|
||||
List the currently available tools in the workbench as :class:`ToolSchema`
|
||||
objects.
|
||||
|
||||
The list of tools may be dynamic, and their content may change after
|
||||
tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool in the workbench.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool to call.
|
||||
arguments (Mapping[str, Any] | None): The arguments to pass to the tool.
|
||||
If None, the tool will be called with no arguments.
|
||||
cancellation_token (CancellationToken | None): An optional cancellation token
|
||||
to cancel the tool execution.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
Returns:
|
||||
ToolResult: The result of the tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the workbench and initialize any resources.
|
||||
|
||||
This method should be called before using the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the workbench and release any resources.
|
||||
|
||||
This method should be called when the workbench is no longer needed.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the workbench to its initialized, started state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Save the state of the workbench.
|
||||
|
||||
This method should be called to persist the state of the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Load the state of the workbench.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): The state to load into the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""
|
||||
Enter the workbench context manager.
|
||||
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~agentdhal_core.tools.WorkBench.start` method to start the workbench.
|
||||
"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
|
||||
) -> None:
|
||||
"""
|
||||
Exit the workbench context manager.
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~agentdhal_core.tools.WorkBench.stop` method to stop the workbench.
|
||||
"""
|
||||
await self.stop()
|
||||
|
||||
|
||||
class StreamWorkbench(Workbench, ABC):
|
||||
"""A workbench that supports streaming results from tool calls."""
|
||||
|
||||
@abstractmethod
|
||||
def call_tool_stream(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[Any | ToolResult, None]:
|
||||
"""
|
||||
Call a tool in the workbench and return a stream of results.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool to call.
|
||||
arguments (Mapping[str, Any] | None): The arguments to pass to the tool
|
||||
If None, the tool will be called with no arguments.
|
||||
cancellation_token (CancellationToken | None): An optional cancellation token
|
||||
to cancel the tool execution.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
"""
|
||||
...
|
||||
4
agent_dhal/agentdhal_core/utils/__init__.py
Normal file
4
agent_dhal/agentdhal_core/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from ._json_to_pydantic import schema_to_pydantic_model
|
||||
from ._load_json import extract_json_from_str
|
||||
|
||||
__all__ = ["schema_to_pydantic_model", "extract_json_from_str"]
|
||||
567
agent_dhal/agentdhal_core/utils/_json_to_pydantic.py
Normal file
567
agent_dhal/agentdhal_core/utils/_json_to_pydantic.py
Normal file
@@ -0,0 +1,567 @@
|
||||
import datetime
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import Annotated, Any, Dict, ForwardRef, List, Literal, Optional, Type, Union, cast
|
||||
|
||||
from pydantic import (
|
||||
UUID1,
|
||||
UUID3,
|
||||
UUID4,
|
||||
UUID5,
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
EmailStr,
|
||||
Field,
|
||||
Json,
|
||||
conbytes,
|
||||
confloat,
|
||||
conint,
|
||||
conlist,
|
||||
constr,
|
||||
create_model,
|
||||
)
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
class SchemaConversionError(Exception):
|
||||
"""Base class for schema conversion exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ReferenceNotFoundError(SchemaConversionError):
|
||||
"""Raised when a $ref cannot be resolved."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FormatNotSupportedError(SchemaConversionError):
|
||||
"""Raised when a format is not supported."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedKeywordError(SchemaConversionError):
|
||||
"""Raised when an unsupported JSON Schema keyword is encountered."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
TYPE_MAPPING: Dict[str, Type[Any]] = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"number": float,
|
||||
"array": List,
|
||||
"object": dict,
|
||||
"null": type(None),
|
||||
}
|
||||
|
||||
FORMAT_MAPPING: Dict[str, Any] = {
|
||||
"uuid": UUID4,
|
||||
"uuid1": UUID1,
|
||||
"uuid2": UUID4,
|
||||
"uuid3": UUID3,
|
||||
"uuid4": UUID4,
|
||||
"uuid5": UUID5,
|
||||
"email": EmailStr,
|
||||
"uri": AnyUrl,
|
||||
"hostname": constr(strict=True),
|
||||
"ipv4": IPv4Address,
|
||||
"ipv6": IPv6Address,
|
||||
"ipv4-network": IPv4Address,
|
||||
"ipv6-network": IPv6Address,
|
||||
"date-time": datetime.datetime,
|
||||
"date": datetime.date,
|
||||
"time": datetime.time,
|
||||
"duration": datetime.timedelta,
|
||||
"int32": conint(strict=True, ge=-(2**31), le=2**31 - 1),
|
||||
"int64": conint(strict=True, ge=-(2**63), le=2**63 - 1),
|
||||
"float": confloat(strict=True),
|
||||
"double": float,
|
||||
"decimal": float,
|
||||
"byte": conbytes(strict=True),
|
||||
"binary": conbytes(strict=True),
|
||||
"password": str,
|
||||
"path": str,
|
||||
"json": Json,
|
||||
}
|
||||
|
||||
|
||||
def _make_field(
|
||||
default: Any,
|
||||
*,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Construct a Pydantic Field with proper typing."""
|
||||
field_kwargs: Dict[str, Any] = {}
|
||||
if title is not None:
|
||||
field_kwargs["title"] = title
|
||||
if description is not None:
|
||||
field_kwargs["description"] = description
|
||||
return Field(default, **field_kwargs)
|
||||
|
||||
|
||||
class _JSONSchemaToPydantic:
|
||||
def __init__(self) -> None:
|
||||
self._model_cache: Dict[str, Optional[Union[Type[BaseModel], ForwardRef]]] = {}
|
||||
|
||||
def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ref_key = ref.split("/")[-1]
|
||||
definitions = cast(dict[str, dict[str, Any]], schema.get("$defs", {}))
|
||||
|
||||
if ref_key not in definitions:
|
||||
raise ReferenceNotFoundError(
|
||||
f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}"
|
||||
)
|
||||
|
||||
return definitions[ref_key]
|
||||
|
||||
def get_ref(self, ref_name: str) -> Any:
|
||||
if ref_name not in self._model_cache:
|
||||
raise ReferenceNotFoundError(
|
||||
f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}"
|
||||
)
|
||||
|
||||
if self._model_cache[ref_name] is None:
|
||||
return ForwardRef(ref_name)
|
||||
|
||||
return self._model_cache[ref_name]
|
||||
|
||||
def _process_definitions(self, root_schema: Dict[str, Any]) -> None:
|
||||
if "$defs" in root_schema:
|
||||
for model_name in root_schema["$defs"]:
|
||||
if model_name not in self._model_cache:
|
||||
self._model_cache[model_name] = None
|
||||
|
||||
for model_name, model_schema in root_schema["$defs"].items():
|
||||
if self._model_cache[model_name] is None:
|
||||
self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema)
|
||||
|
||||
def json_schema_to_pydantic(
|
||||
self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None
|
||||
) -> Type[BaseModel]:
|
||||
if root_schema is None:
|
||||
root_schema = schema
|
||||
self._process_definitions(root_schema)
|
||||
|
||||
if "$ref" in schema:
|
||||
resolved = self._resolve_ref(schema["$ref"], root_schema)
|
||||
schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}}
|
||||
|
||||
if "allOf" in schema:
|
||||
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
for s in schema["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in schema.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
schema = merged
|
||||
|
||||
return self._json_schema_to_model(schema, model_name, root_schema)
|
||||
|
||||
def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]:
|
||||
types: List[Any] = []
|
||||
for s in schemas:
|
||||
if "$ref" in s:
|
||||
types.append(self.get_ref(s["$ref"].split("/")[-1]))
|
||||
elif "enum" in s:
|
||||
types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any)
|
||||
else:
|
||||
json_type = s.get("type")
|
||||
if json_type not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union")
|
||||
|
||||
# Handle array types with items specification
|
||||
if json_type == "array" and "items" in s:
|
||||
item_schema = s["items"]
|
||||
if "$ref" in item_schema:
|
||||
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
|
||||
else:
|
||||
item_type_name = item_schema.get("type")
|
||||
if item_type_name is None:
|
||||
item_type = str
|
||||
elif item_type_name not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(f"Unsupported item type `{item_type_name}` in union array")
|
||||
else:
|
||||
item_type = TYPE_MAPPING[item_type_name]
|
||||
|
||||
constraints = {}
|
||||
if "minItems" in s:
|
||||
constraints["min_length"] = s["minItems"]
|
||||
if "maxItems" in s:
|
||||
constraints["max_length"] = s["maxItems"]
|
||||
|
||||
array_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type]
|
||||
types.append(array_type)
|
||||
else:
|
||||
types.append(TYPE_MAPPING[json_type])
|
||||
return types
|
||||
|
||||
def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any:
|
||||
json_type = value.get("type")
|
||||
if json_type not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(
|
||||
f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`"
|
||||
)
|
||||
|
||||
base_type = TYPE_MAPPING[json_type]
|
||||
constraints: Dict[str, Any] = {}
|
||||
|
||||
if json_type == "string":
|
||||
if "minLength" in value:
|
||||
constraints["min_length"] = value["minLength"]
|
||||
if "maxLength" in value:
|
||||
constraints["max_length"] = value["maxLength"]
|
||||
if "pattern" in value:
|
||||
constraints["pattern"] = value["pattern"]
|
||||
if constraints:
|
||||
base_type = constr(**constraints)
|
||||
|
||||
elif json_type == "integer":
|
||||
if "minimum" in value:
|
||||
constraints["ge"] = value["minimum"]
|
||||
if "maximum" in value:
|
||||
constraints["le"] = value["maximum"]
|
||||
if "exclusiveMinimum" in value:
|
||||
constraints["gt"] = value["exclusiveMinimum"]
|
||||
if "exclusiveMaximum" in value:
|
||||
constraints["lt"] = value["exclusiveMaximum"]
|
||||
if constraints:
|
||||
base_type = conint(**constraints)
|
||||
|
||||
elif json_type == "number":
|
||||
if "minimum" in value:
|
||||
constraints["ge"] = value["minimum"]
|
||||
if "maximum" in value:
|
||||
constraints["le"] = value["maximum"]
|
||||
if "exclusiveMinimum" in value:
|
||||
constraints["gt"] = value["exclusiveMinimum"]
|
||||
if "exclusiveMaximum" in value:
|
||||
constraints["lt"] = value["exclusiveMaximum"]
|
||||
if constraints:
|
||||
base_type = confloat(**constraints)
|
||||
|
||||
elif json_type == "array":
|
||||
if "minItems" in value:
|
||||
constraints["min_length"] = value["minItems"]
|
||||
if "maxItems" in value:
|
||||
constraints["max_length"] = value["maxItems"]
|
||||
item_schema = value.get("items", {"type": "string"})
|
||||
if "$ref" in item_schema:
|
||||
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
|
||||
else:
|
||||
item_type_name = item_schema.get("type")
|
||||
if item_type_name is None:
|
||||
item_type = str
|
||||
elif item_type_name not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(
|
||||
f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`"
|
||||
)
|
||||
else:
|
||||
item_type = TYPE_MAPPING[item_type_name]
|
||||
|
||||
base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type]
|
||||
|
||||
if "format" in value:
|
||||
format_type = FORMAT_MAPPING.get(value["format"])
|
||||
if format_type is None:
|
||||
raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`")
|
||||
if not isinstance(format_type, type):
|
||||
return format_type
|
||||
if not issubclass(format_type, str):
|
||||
return format_type
|
||||
return format_type
|
||||
|
||||
return base_type
|
||||
|
||||
def _json_schema_to_model(
|
||||
self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]
|
||||
) -> Type[BaseModel]:
|
||||
if "allOf" in schema:
|
||||
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
for s in schema["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in schema.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
schema = merged
|
||||
|
||||
fields: Dict[str, tuple[Any, FieldInfo]] = {}
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
for key, value in schema.get("properties", {}).items():
|
||||
if "$ref" in value:
|
||||
ref_name = value["$ref"].split("/")[-1]
|
||||
field_type = self.get_ref(ref_name)
|
||||
elif "anyOf" in value:
|
||||
sub_models = self._resolve_union_types(value["anyOf"])
|
||||
field_type = Union[tuple(sub_models)]
|
||||
elif "oneOf" in value:
|
||||
sub_models = self._resolve_union_types(value["oneOf"])
|
||||
field_type = Union[tuple(sub_models)]
|
||||
if "discriminator" in value:
|
||||
discriminator = value["discriminator"]["propertyName"]
|
||||
field_type = Annotated[field_type, Field(discriminator=discriminator)]
|
||||
elif "enum" in value:
|
||||
field_type = Literal[tuple(value["enum"])]
|
||||
elif "allOf" in value:
|
||||
merged = {"type": "object", "properties": {}, "required": []}
|
||||
for s in value["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in value.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema)
|
||||
elif value.get("type") == "object" and "properties" in value:
|
||||
field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema)
|
||||
else:
|
||||
field_type = self._extract_field_type(key, value, model_name, root_schema)
|
||||
|
||||
if field_type is None:
|
||||
raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`")
|
||||
|
||||
default_value = value.get("default")
|
||||
is_required = key in required_fields
|
||||
|
||||
if not is_required and default_value is None:
|
||||
field_type = Optional[field_type]
|
||||
|
||||
field_args = {
|
||||
"default": default_value if not is_required else ...,
|
||||
}
|
||||
if "title" in value:
|
||||
field_args["title"] = value["title"]
|
||||
if "description" in value:
|
||||
field_args["description"] = value["description"]
|
||||
|
||||
fields[key] = (
|
||||
field_type,
|
||||
_make_field(
|
||||
default_value if not is_required else ...,
|
||||
title=value.get("title"),
|
||||
description=value.get("description"),
|
||||
),
|
||||
)
|
||||
|
||||
model: Type[BaseModel] = create_model(model_name, **cast(dict[str, Any], fields))
|
||||
model.model_rebuild()
|
||||
return model
|
||||
|
||||
|
||||
def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]:
|
||||
"""
|
||||
Convert a JSON Schema dictionary to a fully-typed Pydantic model.
|
||||
|
||||
This function handles schema translation and validation logic to produce
|
||||
a Pydantic model.
|
||||
|
||||
**Supported JSON Schema Features**
|
||||
|
||||
- **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null`
|
||||
- **String formats**:
|
||||
- `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5`
|
||||
- `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network`
|
||||
- `date`, `time`, `date-time`, `duration`
|
||||
- `byte`, `binary`, `password`, `path`
|
||||
- **String constraints**:
|
||||
- `minLength`, `maxLength`, `pattern`
|
||||
- **Numeric constraints**:
|
||||
- `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum`
|
||||
- **Array constraints**:
|
||||
- `minItems`, `maxItems`, `items`
|
||||
- **Object schema support**:
|
||||
- `properties`, `required`, `title`, `description`, `default`
|
||||
- **Enums**:
|
||||
- Converted to Python `Literal` type
|
||||
- **Union types**:
|
||||
- `anyOf`, `oneOf` supported with optional `discriminator`
|
||||
- **Inheritance and composition**:
|
||||
- `allOf` merges multiple schemas into one model
|
||||
- **$ref and $defs resolution**:
|
||||
- Supports references to sibling definitions and self-referencing schemas
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
|
||||
# Example 1: Simple user model
|
||||
schema = {
|
||||
"title": "User",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"age": {"type": "integer", "minimum": 0},
|
||||
},
|
||||
"required": ["name", "email"],
|
||||
}
|
||||
|
||||
UserModel = schema_to_pydantic_model(schema)
|
||||
user = UserModel(name="Alice", email="alice@example.com", age=30)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
|
||||
# Example 2: Nested model
|
||||
schema = {
|
||||
"title": "BlogPost",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
"author": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
"required": ["title", "author"],
|
||||
}
|
||||
|
||||
BlogPost = schema_to_pydantic_model(schema)
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
|
||||
# Example 3: allOf merging with $refs
|
||||
schema = {
|
||||
"title": "EmployeeWithDepartment",
|
||||
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
|
||||
"required": ["id", "name"],
|
||||
},
|
||||
"Department": {
|
||||
"type": "object",
|
||||
"properties": {"department": {"type": "string"}},
|
||||
"required": ["department"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Model = schema_to_pydantic_model(schema)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
|
||||
# Example 4: Self-referencing (recursive) model
|
||||
schema = {
|
||||
"title": "Category",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
|
||||
},
|
||||
"required": ["name"],
|
||||
"$defs": {
|
||||
"Category": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
Category = schema_to_pydantic_model(schema)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 5: Serializing and deserializing with Pydantic
|
||||
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
|
||||
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
zipcode: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: EmailStr
|
||||
age: int = Field(..., ge=18)
|
||||
address: Address
|
||||
|
||||
|
||||
class Employee(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
manager: Optional["Employee"] = None
|
||||
|
||||
|
||||
class Department(BaseModel):
|
||||
name: str
|
||||
employees: List[Employee]
|
||||
|
||||
|
||||
class ComplexModel(BaseModel):
|
||||
user: User
|
||||
extra_info: Optional[Dict[str, Any]] = None
|
||||
sub_items: List[Employee]
|
||||
|
||||
|
||||
# Convert ComplexModel to JSON schema
|
||||
complex_schema = ComplexModel.model_json_schema()
|
||||
|
||||
# Rebuild a new Pydantic model from JSON schema
|
||||
ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel")
|
||||
|
||||
# Instantiate reconstructed model
|
||||
reconstructed = ReconstructedModel(
|
||||
user={
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"email": "alice@example.com",
|
||||
"age": 30,
|
||||
"address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"},
|
||||
},
|
||||
sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}],
|
||||
)
|
||||
|
||||
print(reconstructed.model_dump())
|
||||
|
||||
|
||||
Args:
|
||||
schema (Dict[str, Any]): A valid JSON Schema dictionary.
|
||||
model_name (str, optional): The name of the root model. Defaults to "GeneratedModel".
|
||||
|
||||
Returns:
|
||||
Type[BaseModel]: A dynamically generated Pydantic model class.
|
||||
|
||||
Raises:
|
||||
ReferenceNotFoundError: If a `$ref` key references a missing entry.
|
||||
FormatNotSupportedError: If a `format` keyword is unknown or unsupported.
|
||||
UnsupportedKeywordError: If the schema contains an unsupported `type`.
|
||||
|
||||
See Also:
|
||||
- :class:`pydantic.BaseModel`
|
||||
- :func:`pydantic.create_model`
|
||||
- https://json-schema.org/
|
||||
"""
|
||||
...
|
||||
|
||||
return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name)
|
||||
20
agent_dhal/agentdhal_core/utils/_load_json.py
Normal file
20
agent_dhal/agentdhal_core/utils/_load_json.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def extract_json_from_str(content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract JSON objects from a string. Supports backtick enclosed JSON objects"""
|
||||
pattern = re.compile(r"```(?:\s*([\w\+\-]+))?\n([\s\S]*?)```")
|
||||
matches = pattern.findall(content)
|
||||
ret: List[Dict[str, Any]] = []
|
||||
# If no matches found, assume the entire content is a JSON object
|
||||
if not matches:
|
||||
ret.append(json.loads(content))
|
||||
for match in matches:
|
||||
language = match[0].strip() if match[0] else None
|
||||
if language and language.lower() != "json":
|
||||
raise ValueError(f"Expected JSON object, but found language: {language}")
|
||||
content = match[1]
|
||||
ret.append(json.loads(content))
|
||||
return ret
|
||||
Reference in New Issue
Block a user