first commit
This commit is contained in:
31
agent_dhal/agentdhal_core/tools/__init__.py
Normal file
31
agent_dhal/agentdhal_core/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from ._base import (
|
||||
BaseStreamTool,
|
||||
BaseTool,
|
||||
BaseToolWithState,
|
||||
ParametersSchema,
|
||||
StreamTool,
|
||||
Tool,
|
||||
ToolOverride,
|
||||
ToolSchema,
|
||||
)
|
||||
from ._function_tool import FunctionTool
|
||||
from ._static_workbench import StaticStreamWorkbench, StaticWorkbench
|
||||
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"StreamTool",
|
||||
"ToolSchema",
|
||||
"ParametersSchema",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"BaseStreamTool",
|
||||
"FunctionTool",
|
||||
"Workbench",
|
||||
"ToolResult",
|
||||
"TextResultContent",
|
||||
"ImageResultContent",
|
||||
"StaticWorkbench",
|
||||
"StaticStreamWorkbench",
|
||||
"ToolOverride",
|
||||
]
|
||||
294
agent_dhal/agentdhal_core/tools/_base.py
Normal file
294
agent_dhal/agentdhal_core/tools/_base.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import jsonref
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .. import EVENT_LOGGER_NAME, CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._function_utils import normalize_annotated_type
|
||||
from .._telemetry import trace_tool_span
|
||||
from ..logging import ToolCallEvent
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, contravariant=True)
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ParametersSchema(TypedDict):
|
||||
type: str
|
||||
properties: Dict[str, Any]
|
||||
required: NotRequired[Sequence[str]]
|
||||
additionalProperties: NotRequired[bool]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
|
||||
parameters: NotRequired[ParametersSchema]
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
strict: NotRequired[bool]
|
||||
|
||||
|
||||
class ToolOverride(BaseModel):
|
||||
"""Override configuration for a tool's name and/or description."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema: ...
|
||||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
def return_type(self) -> Type[Any]: ...
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None: ...
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str: ...
|
||||
|
||||
async def run_json(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> Any: ...
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]: ...
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StreamTool(Tool, Protocol):
|
||||
def run_json_stream(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> AsyncGenerator[Any, None]: ...
|
||||
|
||||
|
||||
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
|
||||
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
|
||||
StateT = TypeVar("StateT", bound=BaseModel)
|
||||
StreamT = TypeVar("StreamT", bound=BaseModel, covariant=True)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
|
||||
component_type = "tool"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
name: str,
|
||||
description: str,
|
||||
strict: bool = False,
|
||||
) -> None:
|
||||
self._args_type = args_type
|
||||
# Normalize Annotated to the base type.
|
||||
self._return_type = normalize_annotated_type(return_type)
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._strict = strict
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
model_schema: Dict[str, Any] = self._args_type.model_json_schema()
|
||||
|
||||
if "$defs" in model_schema:
|
||||
model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False)) # type: ignore
|
||||
del model_schema["$defs"]
|
||||
|
||||
parameters = ParametersSchema(
|
||||
type="object",
|
||||
properties=model_schema["properties"],
|
||||
required=model_schema.get("required", []),
|
||||
additionalProperties=model_schema.get("additionalProperties", False),
|
||||
)
|
||||
|
||||
# If strict is enabled, the tool schema should list all properties as required.
|
||||
assert "required" in parameters
|
||||
if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()):
|
||||
raise ValueError(
|
||||
"Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode."
|
||||
)
|
||||
|
||||
assert "additionalProperties" in parameters
|
||||
if self._strict and parameters["additionalProperties"]:
|
||||
raise ValueError(
|
||||
"Strict mode is enabled but additional argument is also enabled. This is not allowed in strict mode."
|
||||
)
|
||||
|
||||
tool_schema = ToolSchema(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
parameters=parameters,
|
||||
strict=self._strict,
|
||||
)
|
||||
return tool_schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
def args_type(self) -> Type[BaseModel]:
|
||||
return self._args_type
|
||||
|
||||
def return_type(self) -> Type[Any]:
|
||||
return self._return_type
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str:
|
||||
if isinstance(value, BaseModel):
|
||||
dumped = value.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return json.dumps(dumped)
|
||||
return str(dumped)
|
||||
|
||||
return str(value)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
|
||||
|
||||
async def run_json(
|
||||
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
|
||||
) -> Any:
|
||||
"""Run the tool with the provided arguments in a dictionary.
|
||||
|
||||
Args:
|
||||
args (Mapping[str, Any]): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): A token to cancel the operation if needed.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
|
||||
Returns:
|
||||
Any: The return value of the tool's run method.
|
||||
"""
|
||||
with trace_tool_span(
|
||||
tool_name=self._name,
|
||||
tool_description=self._description,
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
# Execute the tool's run method
|
||||
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
|
||||
|
||||
# Log the tool call event
|
||||
event = ToolCallEvent(
|
||||
tool_name=self.name,
|
||||
arguments=dict(args), # Using the raw args passed to run_json
|
||||
result=self.return_value_as_string(return_value),
|
||||
)
|
||||
logger.info(event)
|
||||
|
||||
return return_value
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BaseStreamTool(
|
||||
BaseTool[ArgsT, ReturnT], StreamTool, ABC, Generic[ArgsT, StreamT, ReturnT], ComponentBase[BaseModel]
|
||||
):
|
||||
component_type = "tool"
|
||||
|
||||
@abstractmethod
|
||||
def run_stream(self, args: ArgsT, cancellation_token: CancellationToken) -> AsyncGenerator[StreamT | ReturnT, None]:
|
||||
"""Run the tool with the provided arguments and return a stream of data and end with the final return value."""
|
||||
...
|
||||
|
||||
async def run_json_stream(
|
||||
self,
|
||||
args: Mapping[str, Any],
|
||||
cancellation_token: CancellationToken,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[StreamT | ReturnT, None]:
|
||||
"""Run the tool with the provided arguments in a dictionary and return a stream of data
|
||||
from the tool's :meth:`run_stream` method and end with the final return value.
|
||||
|
||||
Args:
|
||||
args (Mapping[str, Any]): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): A token to cancel the operation if needed.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[StreamT | ReturnT, None]: A generator yielding results from the tool's :meth:`run_stream` method.
|
||||
"""
|
||||
return_value: ReturnT | StreamT | None = None
|
||||
with trace_tool_span(
|
||||
tool_name=self._name,
|
||||
tool_description=self._description,
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
# Execute the tool's run_stream method
|
||||
async for result in self.run_stream(self._args_type.model_validate(args), cancellation_token):
|
||||
return_value = result
|
||||
yield result
|
||||
|
||||
assert return_value is not None, "The tool must yield a final return value at the end of the stream."
|
||||
if not isinstance(return_value, self._return_type):
|
||||
raise TypeError(
|
||||
f"Expected return value of type {self._return_type.__name__}, but got {type(return_value).__name__}"
|
||||
)
|
||||
|
||||
# Log the tool call event
|
||||
event = ToolCallEvent(
|
||||
tool_name=self.name,
|
||||
arguments=dict(args), # Using the raw args passed to run_json
|
||||
result=self.return_value_as_string(return_value),
|
||||
)
|
||||
logger.info(event)
|
||||
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
state_type: Type[StateT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
self._state_type = state_type
|
||||
|
||||
component_type = "tool"
|
||||
|
||||
@abstractmethod
|
||||
def save_state(self) -> StateT: ...
|
||||
|
||||
@abstractmethod
|
||||
def load_state(self, state: StateT) -> None: ...
|
||||
|
||||
async def save_state_json(self) -> Mapping[str, Any]:
|
||||
return self.save_state().model_dump()
|
||||
|
||||
async def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
self.load_state(self._state_type.model_validate(state))
|
||||
181
agent_dhal/agentdhal_core/tools/_function_tool.py
Normal file
181
agent_dhal/agentdhal_core/tools/_function_tool.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import warnings
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import CancellationToken
|
||||
from .._component_config import Component
|
||||
from .._function_utils import (
|
||||
args_base_model_from_signature,
|
||||
get_typed_signature,
|
||||
)
|
||||
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class FunctionToolConfig(BaseModel):
|
||||
"""Configuration for a function tool."""
|
||||
|
||||
source_code: str
|
||||
name: str
|
||||
description: str
|
||||
global_imports: Sequence[Import]
|
||||
has_cancellation_support: bool
|
||||
|
||||
|
||||
class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
|
||||
"""
|
||||
Create custom tools by wrapping standard Python functions.
|
||||
|
||||
`FunctionTool` offers an interface for executing Python functions either asynchronously or synchronously.
|
||||
Each function must include type annotations for all parameters and its return type. These annotations
|
||||
enable `FunctionTool` to generate a schema necessary for input validation, serialization, and for informing
|
||||
the LLM about expected parameters. When the LLM prepares a function call, it leverages this schema to
|
||||
generate arguments that align with the function's specifications.
|
||||
|
||||
.. note::
|
||||
|
||||
It is the user's responsibility to verify that the tool's output type matches the expected type.
|
||||
|
||||
Args:
|
||||
func (Callable[..., ReturnT | Awaitable[ReturnT]]): The function to wrap and expose as a tool.
|
||||
description (str): A description to inform the model of the function's purpose, specifying what
|
||||
it does and the context in which it should be called.
|
||||
name (str, optional): An optional custom name for the tool. Defaults to
|
||||
the function's original name if not provided.
|
||||
strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly
|
||||
defined in the function signature, and no default values will be allowed. Defaults to False.
|
||||
This is required to be set to True when used with models in structured output mode.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import random
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import FunctionTool
|
||||
from typing_extensions import Annotated
|
||||
import asyncio
|
||||
|
||||
|
||||
async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float:
|
||||
# Simulates a stock price retrieval by returning a random float within a specified range.
|
||||
return random.uniform(10, 200)
|
||||
|
||||
|
||||
async def example():
|
||||
# Initialize a FunctionTool instance for retrieving stock prices.
|
||||
stock_price_tool = FunctionTool(get_stock_price, description="Fetch the stock price for a given ticker.")
|
||||
|
||||
# Execute the tool with cancellation support.
|
||||
cancellation_token = CancellationToken()
|
||||
result = await stock_price_tool.run_json({"ticker": "AAPL", "date": "2021/01/01"}, cancellation_token)
|
||||
|
||||
# Output the result as a formatted string.
|
||||
print(stock_price_tool.return_value_as_string(result))
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.FunctionTool"
|
||||
component_config_schema = FunctionToolConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
description: str,
|
||||
name: str | None = None,
|
||||
global_imports: Sequence[Import] = [],
|
||||
strict: bool = False,
|
||||
) -> None:
|
||||
self._func = func
|
||||
self._global_imports = global_imports
|
||||
self._signature = get_typed_signature(func)
|
||||
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
|
||||
args_model = args_base_model_from_signature(func_name + "args", self._signature)
|
||||
self._has_cancellation_support = "cancellation_token" in self._signature.parameters
|
||||
return_type = self._signature.return_annotation
|
||||
super().__init__(args_model, return_type, func_name, description, strict)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
kwargs = {}
|
||||
|
||||
for name in self._signature.parameters.keys():
|
||||
if hasattr(args, name):
|
||||
kwargs[name] = getattr(args, name)
|
||||
|
||||
if asyncio.iscoroutinefunction(self._func):
|
||||
if self._has_cancellation_support:
|
||||
result = await self._func(**kwargs, cancellation_token=cancellation_token)
|
||||
else:
|
||||
result = await self._func(**kwargs)
|
||||
else:
|
||||
if self._has_cancellation_support:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self._func,
|
||||
**kwargs,
|
||||
cancellation_token=cancellation_token,
|
||||
),
|
||||
)
|
||||
else:
|
||||
future = asyncio.get_event_loop().run_in_executor(None, functools.partial(self._func, **kwargs))
|
||||
cancellation_token.link_future(future)
|
||||
result = await future
|
||||
|
||||
return result
|
||||
|
||||
def _to_config(self) -> FunctionToolConfig:
|
||||
return FunctionToolConfig(
|
||||
source_code=dedent(to_code(self._func)),
|
||||
global_imports=self._global_imports,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
has_cancellation_support=self._has_cancellation_support,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: FunctionToolConfig) -> Self:
|
||||
warnings.warn(
|
||||
"\n⚠️ SECURITY WARNING ⚠️\n"
|
||||
"Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
|
||||
"Only load configs from TRUSTED sources to prevent arbitrary code execution.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
exec_globals: dict[str, Any] = {}
|
||||
|
||||
# Execute imports first
|
||||
for import_stmt in config.global_imports:
|
||||
import_code = import_to_str(import_stmt)
|
||||
try:
|
||||
exec(import_code, exec_globals)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
|
||||
) from e
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e
|
||||
|
||||
# Execute function code
|
||||
try:
|
||||
exec(config.source_code, exec_globals)
|
||||
func_name = config.source_code.split("def ")[1].split("(")[0]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not compile and load function: {e}") from e
|
||||
|
||||
# Get function and verify it's callable
|
||||
func: Callable[..., Any] = exec_globals[func_name]
|
||||
if not callable(func):
|
||||
raise TypeError(f"Expected function but got {type(func)}")
|
||||
|
||||
return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)
|
||||
225
agent_dhal/agentdhal_core/tools/_static_workbench.py
Normal file
225
agent_dhal/agentdhal_core/tools/_static_workbench.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
|
||||
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench
|
||||
|
||||
|
||||
class StaticWorkbenchConfig(BaseModel):
|
||||
tools: List[ComponentModel] = []
|
||||
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StateicWorkbenchState(BaseModel):
|
||||
type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
|
||||
tools: Dict[str, Mapping[str, Any]] = {}
|
||||
|
||||
|
||||
class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
"""
|
||||
A workbench that provides a static set of tools that do not change after
|
||||
each tool execution.
|
||||
|
||||
Args:
|
||||
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
|
||||
The tools should be subclasses of :class:`~agentdhal_core.tools.BaseTool`.
|
||||
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
|
||||
names to override configurations for name and/or description. This allows
|
||||
customizing how tools appear to consumers while maintaining the underlying
|
||||
tool functionality.
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.StaticWorkbench"
|
||||
component_config_schema = StaticWorkbenchConfig
|
||||
|
||||
def __init__(
|
||||
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
|
||||
) -> None:
|
||||
self._tools = tools
|
||||
self._tool_overrides = tool_overrides or {}
|
||||
|
||||
# Build reverse mapping from override names to original names for call_tool
|
||||
self._override_name_to_original: Dict[str, str] = {}
|
||||
existing_tool_names = {tool.name for tool in self._tools}
|
||||
|
||||
for original_name, override in self._tool_overrides.items():
|
||||
if override.name and override.name != original_name:
|
||||
# Check for conflicts with existing tool names
|
||||
if override.name in existing_tool_names and override.name != original_name:
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' conflicts with existing tool name. "
|
||||
f"Override names must not conflict with any tool names."
|
||||
)
|
||||
# Check for conflicts with other override names
|
||||
if override.name in self._override_name_to_original:
|
||||
existing_original = self._override_name_to_original[override.name]
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' is used by multiple tools: "
|
||||
f"'{existing_original}' and '{original_name}'. Override names must be unique."
|
||||
)
|
||||
self._override_name_to_original[override.name] = original_name
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
result_schemas: List[ToolSchema] = []
|
||||
for tool in self._tools:
|
||||
original_schema = tool.schema
|
||||
|
||||
# Apply overrides if they exist for this tool
|
||||
if tool.name in self._tool_overrides:
|
||||
override = self._tool_overrides[tool.name]
|
||||
# Create a new ToolSchema with overrides applied
|
||||
schema: ToolSchema = {
|
||||
"name": override.name if override.name is not None else original_schema["name"],
|
||||
"description": override.description
|
||||
if override.description is not None
|
||||
else original_schema.get("description", ""),
|
||||
}
|
||||
# Copy optional fields
|
||||
if "parameters" in original_schema:
|
||||
schema["parameters"] = original_schema["parameters"]
|
||||
if "strict" in original_schema:
|
||||
schema["strict"] = original_schema["strict"]
|
||||
else:
|
||||
schema = original_schema
|
||||
|
||||
result_schemas.append(schema)
|
||||
return result_schemas
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
# Check if the name is an override name and map it back to the original
|
||||
original_name = self._override_name_to_original.get(name, name)
|
||||
|
||||
tool = next((tool for tool in self._tools if tool.name == original_name), None)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
name=name, # Return the requested name (which might be overridden)
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
|
||||
cancellation_token.link_future(result_future)
|
||||
actual_tool_output = await result_future
|
||||
is_error = False
|
||||
result_str = tool.return_value_as_string(actual_tool_output)
|
||||
except Exception as e:
|
||||
result_str = self._format_errors(e)
|
||||
is_error = True
|
||||
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
return None
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
tool_states = StateicWorkbenchState()
|
||||
for tool in self._tools:
|
||||
tool_states.tools[tool.name] = await tool.save_state_json()
|
||||
return tool_states.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
parsed_state = StateicWorkbenchState.model_validate(state)
|
||||
for tool in self._tools:
|
||||
if tool.name in parsed_state.tools:
|
||||
await tool.load_state_json(parsed_state.tools[tool.name])
|
||||
|
||||
def _to_config(self) -> StaticWorkbenchConfig:
|
||||
return StaticWorkbenchConfig(
|
||||
tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
|
||||
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
|
||||
error_message = ""
|
||||
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
|
||||
# ExceptionGroup is available in Python 3.11+.
|
||||
# TODO: how to make this compatible with Python 3.10?
|
||||
for sub_exception in error.exceptions: # type: ignore
|
||||
error_message += self._format_errors(sub_exception) # type: ignore
|
||||
else:
|
||||
error_message += f"{str(error)}\n"
|
||||
return error_message.strip()
|
||||
|
||||
|
||||
class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench):
|
||||
"""
|
||||
A workbench that provides a static set of tools that do not change after
|
||||
each tool execution, and supports streaming results.
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_core.tools.StaticStreamWorkbench"
|
||||
|
||||
async def call_tool_stream(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[Any | ToolResult, None]:
|
||||
tool = next((tool for tool in self._tools if tool.name == name), None)
|
||||
if tool is None:
|
||||
yield ToolResult(
|
||||
name=name,
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
return
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
actual_tool_output: Any | None = None
|
||||
if isinstance(tool, StreamTool):
|
||||
previous_result: Any | None = None
|
||||
try:
|
||||
async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id):
|
||||
if previous_result is not None:
|
||||
yield previous_result
|
||||
previous_result = result
|
||||
actual_tool_output = previous_result
|
||||
except Exception as e:
|
||||
# If there was a previous result before the exception, yield it first
|
||||
if previous_result is not None:
|
||||
yield previous_result
|
||||
# Then yield the error result
|
||||
result_str = self._format_errors(e)
|
||||
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True)
|
||||
return
|
||||
else:
|
||||
# If the tool is not a stream tool, we run it normally and yield the result
|
||||
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
|
||||
cancellation_token.link_future(result_future)
|
||||
actual_tool_output = await result_future
|
||||
is_error = False
|
||||
result_str = tool.return_value_as_string(actual_tool_output)
|
||||
except Exception as e:
|
||||
result_str = self._format_errors(e)
|
||||
is_error = True
|
||||
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
216
agent_dhal/agentdhal_core/tools/_workbench.py
Normal file
216
agent_dhal/agentdhal_core/tools/_workbench.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._image import Image
|
||||
from ._base import ToolSchema
|
||||
|
||||
|
||||
class TextResultContent(BaseModel):
|
||||
"""
|
||||
Text result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["TextResultContent"] = "TextResultContent"
|
||||
|
||||
content: str
|
||||
"""The text content of the result."""
|
||||
|
||||
|
||||
class ImageResultContent(BaseModel):
|
||||
"""
|
||||
Image result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["ImageResultContent"] = "ImageResultContent"
|
||||
|
||||
content: Image
|
||||
"""The image content of the result."""
|
||||
|
||||
|
||||
ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")]
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""
|
||||
A result of a tool execution by a workbench.
|
||||
"""
|
||||
|
||||
type: Literal["ToolResult"] = "ToolResult"
|
||||
|
||||
name: str
|
||||
"""The name of the tool that was executed."""
|
||||
|
||||
result: List[ResultContent]
|
||||
"""The result of the tool execution."""
|
||||
|
||||
is_error: bool = False
|
||||
"""Whether the tool execution resulted in an error."""
|
||||
|
||||
def to_text(self, replace_image: str | None = None) -> str:
|
||||
"""
|
||||
Convert the result to a text string.
|
||||
|
||||
Args:
|
||||
replace_image (str | None): The string to replace the image content with.
|
||||
If None, the image content will be included in the text as base64 string.
|
||||
|
||||
Returns:
|
||||
str: The text representation of the result.
|
||||
"""
|
||||
parts: List[str] = []
|
||||
for content in self.result:
|
||||
if isinstance(content, TextResultContent):
|
||||
parts.append(content.content)
|
||||
elif isinstance(content, ImageResultContent):
|
||||
if replace_image is not None:
|
||||
parts.append(replace_image)
|
||||
else:
|
||||
parts.append(f"[Image: {content.content.to_base64()}]")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class Workbench(ABC, ComponentBase[BaseModel]):
|
||||
"""
|
||||
A workbench is a component that provides a set of tools that may share
|
||||
resources and state.
|
||||
|
||||
A workbench is responsible for managing the lifecycle of the tools and
|
||||
providing a single interface to call them. The tools provided by the workbench
|
||||
may be dynamic and their availabilities may change after each tool execution.
|
||||
|
||||
A workbench can be started by calling the :meth:`~agentdhal_core.tools.Workbench.start` method
|
||||
and stopped by calling the :meth:`~agentdhal_core.tools.Workbench.stop` method.
|
||||
It can also be used as an asynchronous context manager, which will automatically
|
||||
start and stop the workbench when entering and exiting the context.
|
||||
"""
|
||||
|
||||
component_type = "workbench"
|
||||
|
||||
@abstractmethod
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
"""
|
||||
List the currently available tools in the workbench as :class:`ToolSchema`
|
||||
objects.
|
||||
|
||||
The list of tools may be dynamic, and their content may change after
|
||||
tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool in the workbench.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool to call.
|
||||
arguments (Mapping[str, Any] | None): The arguments to pass to the tool.
|
||||
If None, the tool will be called with no arguments.
|
||||
cancellation_token (CancellationToken | None): An optional cancellation token
|
||||
to cancel the tool execution.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
Returns:
|
||||
ToolResult: The result of the tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the workbench and initialize any resources.
|
||||
|
||||
This method should be called before using the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the workbench and release any resources.
|
||||
|
||||
This method should be called when the workbench is no longer needed.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the workbench to its initialized, started state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Save the state of the workbench.
|
||||
|
||||
This method should be called to persist the state of the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Load the state of the workbench.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): The state to load into the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""
|
||||
Enter the workbench context manager.
|
||||
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~agentdhal_core.tools.WorkBench.start` method to start the workbench.
|
||||
"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
|
||||
) -> None:
|
||||
"""
|
||||
Exit the workbench context manager.
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~agentdhal_core.tools.WorkBench.stop` method to stop the workbench.
|
||||
"""
|
||||
await self.stop()
|
||||
|
||||
|
||||
class StreamWorkbench(Workbench, ABC):
|
||||
"""A workbench that supports streaming results from tool calls."""
|
||||
|
||||
@abstractmethod
|
||||
def call_tool_stream(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> AsyncGenerator[Any | ToolResult, None]:
|
||||
"""
|
||||
Call a tool in the workbench and return a stream of results.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool to call.
|
||||
arguments (Mapping[str, Any] | None): The arguments to pass to the tool
|
||||
If None, the tool will be called with no arguments.
|
||||
cancellation_token (CancellationToken | None): An optional cancellation token
|
||||
to cancel the tool execution.
|
||||
call_id (str | None): An optional identifier for the tool call, used for tracing.
|
||||
"""
|
||||
...
|
||||
Reference in New Issue
Block a user