first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
from ._base import (
BaseStreamTool,
BaseTool,
BaseToolWithState,
ParametersSchema,
StreamTool,
Tool,
ToolOverride,
ToolSchema,
)
from ._function_tool import FunctionTool
from ._static_workbench import StaticStreamWorkbench, StaticWorkbench
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
__all__ = [
"Tool",
"StreamTool",
"ToolSchema",
"ParametersSchema",
"BaseTool",
"BaseToolWithState",
"BaseStreamTool",
"FunctionTool",
"Workbench",
"ToolResult",
"TextResultContent",
"ImageResultContent",
"StaticWorkbench",
"StaticStreamWorkbench",
"ToolOverride",
]

View File

@@ -0,0 +1,294 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import (
Any,
AsyncGenerator,
Dict,
Generic,
Mapping,
Optional,
Protocol,
Type,
TypeVar,
cast,
runtime_checkable,
)
import jsonref
from pydantic import BaseModel
from typing_extensions import NotRequired, TypedDict
from .. import EVENT_LOGGER_NAME, CancellationToken
from .._component_config import ComponentBase
from .._function_utils import normalize_annotated_type
from .._telemetry import trace_tool_span
from ..logging import ToolCallEvent
T = TypeVar("T", bound=BaseModel, contravariant=True)
logger = logging.getLogger(EVENT_LOGGER_NAME)
class ParametersSchema(TypedDict):
type: str
properties: Dict[str, Any]
required: NotRequired[Sequence[str]]
additionalProperties: NotRequired[bool]
class ToolSchema(TypedDict):
parameters: NotRequired[ParametersSchema]
name: str
description: NotRequired[str]
strict: NotRequired[bool]
class ToolOverride(BaseModel):
"""Override configuration for a tool's name and/or description."""
name: Optional[str] = None
description: Optional[str] = None
@runtime_checkable
class Tool(Protocol):
@property
def name(self) -> str: ...
@property
def description(self) -> str: ...
@property
def schema(self) -> ToolSchema: ...
def args_type(self) -> Type[BaseModel]: ...
def return_type(self) -> Type[Any]: ...
def state_type(self) -> Type[BaseModel] | None: ...
def return_value_as_string(self, value: Any) -> str: ...
async def run_json(
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
) -> Any: ...
async def save_state_json(self) -> Mapping[str, Any]: ...
async def load_state_json(self, state: Mapping[str, Any]) -> None: ...
@runtime_checkable
class StreamTool(Tool, Protocol):
def run_json_stream(
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
) -> AsyncGenerator[Any, None]: ...
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
StateT = TypeVar("StateT", bound=BaseModel)
StreamT = TypeVar("StreamT", bound=BaseModel, covariant=True)
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
component_type = "tool"
def __init__(
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
name: str,
description: str,
strict: bool = False,
) -> None:
self._args_type = args_type
# Normalize Annotated to the base type.
self._return_type = normalize_annotated_type(return_type)
self._name = name
self._description = description
self._strict = strict
@property
def schema(self) -> ToolSchema:
model_schema: Dict[str, Any] = self._args_type.model_json_schema()
if "$defs" in model_schema:
model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False)) # type: ignore
del model_schema["$defs"]
parameters = ParametersSchema(
type="object",
properties=model_schema["properties"],
required=model_schema.get("required", []),
additionalProperties=model_schema.get("additionalProperties", False),
)
# If strict is enabled, the tool schema should list all properties as required.
assert "required" in parameters
if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()):
raise ValueError(
"Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode."
)
assert "additionalProperties" in parameters
if self._strict and parameters["additionalProperties"]:
raise ValueError(
"Strict mode is enabled but additional argument is also enabled. This is not allowed in strict mode."
)
tool_schema = ToolSchema(
name=self._name,
description=self._description,
parameters=parameters,
strict=self._strict,
)
return tool_schema
@property
def name(self) -> str:
return self._name
@property
def description(self) -> str:
return self._description
def args_type(self) -> Type[BaseModel]:
return self._args_type
def return_type(self) -> Type[Any]:
return self._return_type
def state_type(self) -> Type[BaseModel] | None:
return None
def return_value_as_string(self, value: Any) -> str:
if isinstance(value, BaseModel):
dumped = value.model_dump()
if isinstance(dumped, dict):
return json.dumps(dumped)
return str(dumped)
return str(value)
@abstractmethod
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
async def run_json(
self, args: Mapping[str, Any], cancellation_token: CancellationToken, call_id: str | None = None
) -> Any:
"""Run the tool with the provided arguments in a dictionary.
Args:
args (Mapping[str, Any]): The arguments to pass to the tool.
cancellation_token (CancellationToken): A token to cancel the operation if needed.
call_id (str | None): An optional identifier for the tool call, used for tracing.
Returns:
Any: The return value of the tool's run method.
"""
with trace_tool_span(
tool_name=self._name,
tool_description=self._description,
tool_call_id=call_id,
):
# Execute the tool's run method
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
# Log the tool call event
event = ToolCallEvent(
tool_name=self.name,
arguments=dict(args), # Using the raw args passed to run_json
result=self.return_value_as_string(return_value),
)
logger.info(event)
return return_value
async def save_state_json(self) -> Mapping[str, Any]:
return {}
async def load_state_json(self, state: Mapping[str, Any]) -> None:
pass
class BaseStreamTool(
BaseTool[ArgsT, ReturnT], StreamTool, ABC, Generic[ArgsT, StreamT, ReturnT], ComponentBase[BaseModel]
):
component_type = "tool"
@abstractmethod
def run_stream(self, args: ArgsT, cancellation_token: CancellationToken) -> AsyncGenerator[StreamT | ReturnT, None]:
"""Run the tool with the provided arguments and return a stream of data and end with the final return value."""
...
async def run_json_stream(
self,
args: Mapping[str, Any],
cancellation_token: CancellationToken,
call_id: str | None = None,
) -> AsyncGenerator[StreamT | ReturnT, None]:
"""Run the tool with the provided arguments in a dictionary and return a stream of data
from the tool's :meth:`run_stream` method and end with the final return value.
Args:
args (Mapping[str, Any]): The arguments to pass to the tool.
cancellation_token (CancellationToken): A token to cancel the operation if needed.
call_id (str | None): An optional identifier for the tool call, used for tracing.
Returns:
AsyncGenerator[StreamT | ReturnT, None]: A generator yielding results from the tool's :meth:`run_stream` method.
"""
return_value: ReturnT | StreamT | None = None
with trace_tool_span(
tool_name=self._name,
tool_description=self._description,
tool_call_id=call_id,
):
# Execute the tool's run_stream method
async for result in self.run_stream(self._args_type.model_validate(args), cancellation_token):
return_value = result
yield result
assert return_value is not None, "The tool must yield a final return value at the end of the stream."
if not isinstance(return_value, self._return_type):
raise TypeError(
f"Expected return value of type {self._return_type.__name__}, but got {type(return_value).__name__}"
)
# Log the tool call event
event = ToolCallEvent(
tool_name=self.name,
arguments=dict(args), # Using the raw args passed to run_json
result=self.return_value_as_string(return_value),
)
logger.info(event)
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
def __init__(
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
state_type: Type[StateT],
name: str,
description: str,
) -> None:
super().__init__(args_type, return_type, name, description)
self._state_type = state_type
component_type = "tool"
@abstractmethod
def save_state(self) -> StateT: ...
@abstractmethod
def load_state(self, state: StateT) -> None: ...
async def save_state_json(self) -> Mapping[str, Any]:
return self.save_state().model_dump()
async def load_state_json(self, state: Mapping[str, Any]) -> None:
self.load_state(self._state_type.model_validate(state))

View File

@@ -0,0 +1,181 @@
import asyncio
import functools
import warnings
from textwrap import dedent
from typing import Any, Callable, Sequence
from pydantic import BaseModel
from typing_extensions import Self
from .. import CancellationToken
from .._component_config import Component
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
from ._base import BaseTool
class FunctionToolConfig(BaseModel):
"""Configuration for a function tool."""
source_code: str
name: str
description: str
global_imports: Sequence[Import]
has_cancellation_support: bool
class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
"""
Create custom tools by wrapping standard Python functions.
`FunctionTool` offers an interface for executing Python functions either asynchronously or synchronously.
Each function must include type annotations for all parameters and its return type. These annotations
enable `FunctionTool` to generate a schema necessary for input validation, serialization, and for informing
the LLM about expected parameters. When the LLM prepares a function call, it leverages this schema to
generate arguments that align with the function's specifications.
.. note::
It is the user's responsibility to verify that the tool's output type matches the expected type.
Args:
func (Callable[..., ReturnT | Awaitable[ReturnT]]): The function to wrap and expose as a tool.
description (str): A description to inform the model of the function's purpose, specifying what
it does and the context in which it should be called.
name (str, optional): An optional custom name for the tool. Defaults to
the function's original name if not provided.
strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly
defined in the function signature, and no default values will be allowed. Defaults to False.
This is required to be set to True when used with models in structured output mode.
Example:
.. code-block:: python
import random
from agentdhal_core import CancellationToken
from agentdhal_core.tools import FunctionTool
from typing_extensions import Annotated
import asyncio
async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float:
# Simulates a stock price retrieval by returning a random float within a specified range.
return random.uniform(10, 200)
async def example():
# Initialize a FunctionTool instance for retrieving stock prices.
stock_price_tool = FunctionTool(get_stock_price, description="Fetch the stock price for a given ticker.")
# Execute the tool with cancellation support.
cancellation_token = CancellationToken()
result = await stock_price_tool.run_json({"ticker": "AAPL", "date": "2021/01/01"}, cancellation_token)
# Output the result as a formatted string.
print(stock_price_tool.return_value_as_string(result))
asyncio.run(example())
"""
component_provider_override = "agentdhal_core.tools.FunctionTool"
component_config_schema = FunctionToolConfig
def __init__(
self,
func: Callable[..., Any],
description: str,
name: str | None = None,
global_imports: Sequence[Import] = [],
strict: bool = False,
) -> None:
self._func = func
self._global_imports = global_imports
self._signature = get_typed_signature(func)
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", self._signature)
self._has_cancellation_support = "cancellation_token" in self._signature.parameters
return_type = self._signature.return_annotation
super().__init__(args_model, return_type, func_name, description, strict)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
kwargs = {}
for name in self._signature.parameters.keys():
if hasattr(args, name):
kwargs[name] = getattr(args, name)
if asyncio.iscoroutinefunction(self._func):
if self._has_cancellation_support:
result = await self._func(**kwargs, cancellation_token=cancellation_token)
else:
result = await self._func(**kwargs)
else:
if self._has_cancellation_support:
result = await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
self._func,
**kwargs,
cancellation_token=cancellation_token,
),
)
else:
future = asyncio.get_event_loop().run_in_executor(None, functools.partial(self._func, **kwargs))
cancellation_token.link_future(future)
result = await future
return result
def _to_config(self) -> FunctionToolConfig:
return FunctionToolConfig(
source_code=dedent(to_code(self._func)),
global_imports=self._global_imports,
name=self.name,
description=self.description,
has_cancellation_support=self._has_cancellation_support,
)
@classmethod
def _from_config(cls, config: FunctionToolConfig) -> Self:
warnings.warn(
"\n⚠️ SECURITY WARNING ⚠️\n"
"Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
"Only load configs from TRUSTED sources to prevent arbitrary code execution.",
UserWarning,
stacklevel=2,
)
exec_globals: dict[str, Any] = {}
# Execute imports first
for import_stmt in config.global_imports:
import_code = import_to_str(import_stmt)
try:
exec(import_code, exec_globals)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
) from e
except ImportError as e:
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
except Exception as e:
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e
# Execute function code
try:
exec(config.source_code, exec_globals)
func_name = config.source_code.split("def ")[1].split("(")[0]
except Exception as e:
raise ValueError(f"Could not compile and load function: {e}") from e
# Get function and verify it's callable
func: Callable[..., Any] = exec_globals[func_name]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")
return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)

View File

@@ -0,0 +1,225 @@
import asyncio
import builtins
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
from pydantic import BaseModel, Field
from typing_extensions import Self
from .._cancellation_token import CancellationToken
from .._component_config import Component, ComponentModel
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench
class StaticWorkbenchConfig(BaseModel):
tools: List[ComponentModel] = []
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
class StateicWorkbenchState(BaseModel):
type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
tools: Dict[str, Mapping[str, Any]] = {}
class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
"""
A workbench that provides a static set of tools that do not change after
each tool execution.
Args:
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
The tools should be subclasses of :class:`~agentdhal_core.tools.BaseTool`.
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
names to override configurations for name and/or description. This allows
customizing how tools appear to consumers while maintaining the underlying
tool functionality.
"""
component_provider_override = "agentdhal_core.tools.StaticWorkbench"
component_config_schema = StaticWorkbenchConfig
def __init__(
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
) -> None:
self._tools = tools
self._tool_overrides = tool_overrides or {}
# Build reverse mapping from override names to original names for call_tool
self._override_name_to_original: Dict[str, str] = {}
existing_tool_names = {tool.name for tool in self._tools}
for original_name, override in self._tool_overrides.items():
if override.name and override.name != original_name:
# Check for conflicts with existing tool names
if override.name in existing_tool_names and override.name != original_name:
raise ValueError(
f"Tool override name '{override.name}' conflicts with existing tool name. "
f"Override names must not conflict with any tool names."
)
# Check for conflicts with other override names
if override.name in self._override_name_to_original:
existing_original = self._override_name_to_original[override.name]
raise ValueError(
f"Tool override name '{override.name}' is used by multiple tools: "
f"'{existing_original}' and '{original_name}'. Override names must be unique."
)
self._override_name_to_original[override.name] = original_name
async def list_tools(self) -> List[ToolSchema]:
result_schemas: List[ToolSchema] = []
for tool in self._tools:
original_schema = tool.schema
# Apply overrides if they exist for this tool
if tool.name in self._tool_overrides:
override = self._tool_overrides[tool.name]
# Create a new ToolSchema with overrides applied
schema: ToolSchema = {
"name": override.name if override.name is not None else original_schema["name"],
"description": override.description
if override.description is not None
else original_schema.get("description", ""),
}
# Copy optional fields
if "parameters" in original_schema:
schema["parameters"] = original_schema["parameters"]
if "strict" in original_schema:
schema["strict"] = original_schema["strict"]
else:
schema = original_schema
result_schemas.append(schema)
return result_schemas
async def call_tool(
self,
name: str,
arguments: Mapping[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> ToolResult:
# Check if the name is an override name and map it back to the original
original_name = self._override_name_to_original.get(name, name)
tool = next((tool for tool in self._tools if tool.name == original_name), None)
if tool is None:
return ToolResult(
name=name, # Return the requested name (which might be overridden)
result=[TextResultContent(content=f"Tool {name} not found.")],
is_error=True,
)
if not cancellation_token:
cancellation_token = CancellationToken()
if not arguments:
arguments = {}
try:
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
cancellation_token.link_future(result_future)
actual_tool_output = await result_future
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
result_str = self._format_errors(e)
is_error = True
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
async def start(self) -> None:
return None
async def stop(self) -> None:
return None
async def reset(self) -> None:
return None
async def save_state(self) -> Mapping[str, Any]:
tool_states = StateicWorkbenchState()
for tool in self._tools:
tool_states.tools[tool.name] = await tool.save_state_json()
return tool_states.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
parsed_state = StateicWorkbenchState.model_validate(state)
for tool in self._tools:
if tool.name in parsed_state.tools:
await tool.load_state_json(parsed_state.tools[tool.name])
def _to_config(self) -> StaticWorkbenchConfig:
return StaticWorkbenchConfig(
tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
)
@classmethod
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)
def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""
error_message = ""
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
# ExceptionGroup is available in Python 3.11+.
# TODO: how to make this compatible with Python 3.10?
for sub_exception in error.exceptions: # type: ignore
error_message += self._format_errors(sub_exception) # type: ignore
else:
error_message += f"{str(error)}\n"
return error_message.strip()
class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench):
"""
A workbench that provides a static set of tools that do not change after
each tool execution, and supports streaming results.
"""
component_provider_override = "agentdhal_core.tools.StaticStreamWorkbench"
async def call_tool_stream(
self,
name: str,
arguments: Mapping[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> AsyncGenerator[Any | ToolResult, None]:
tool = next((tool for tool in self._tools if tool.name == name), None)
if tool is None:
yield ToolResult(
name=name,
result=[TextResultContent(content=f"Tool {name} not found.")],
is_error=True,
)
return
if not cancellation_token:
cancellation_token = CancellationToken()
if not arguments:
arguments = {}
try:
actual_tool_output: Any | None = None
if isinstance(tool, StreamTool):
previous_result: Any | None = None
try:
async for result in tool.run_json_stream(arguments, cancellation_token, call_id=call_id):
if previous_result is not None:
yield previous_result
previous_result = result
actual_tool_output = previous_result
except Exception as e:
# If there was a previous result before the exception, yield it first
if previous_result is not None:
yield previous_result
# Then yield the error result
result_str = self._format_errors(e)
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True)
return
else:
# If the tool is not a stream tool, we run it normally and yield the result
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
cancellation_token.link_future(result_future)
actual_tool_output = await result_future
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
result_str = self._format_errors(e)
is_error = True
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)

View File

@@ -0,0 +1,216 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Type
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Self
from .._cancellation_token import CancellationToken
from .._component_config import ComponentBase
from .._image import Image
from ._base import ToolSchema
class TextResultContent(BaseModel):
"""
Text result content of a tool execution.
"""
type: Literal["TextResultContent"] = "TextResultContent"
content: str
"""The text content of the result."""
class ImageResultContent(BaseModel):
"""
Image result content of a tool execution.
"""
type: Literal["ImageResultContent"] = "ImageResultContent"
content: Image
"""The image content of the result."""
ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")]
class ToolResult(BaseModel):
"""
A result of a tool execution by a workbench.
"""
type: Literal["ToolResult"] = "ToolResult"
name: str
"""The name of the tool that was executed."""
result: List[ResultContent]
"""The result of the tool execution."""
is_error: bool = False
"""Whether the tool execution resulted in an error."""
def to_text(self, replace_image: str | None = None) -> str:
"""
Convert the result to a text string.
Args:
replace_image (str | None): The string to replace the image content with.
If None, the image content will be included in the text as base64 string.
Returns:
str: The text representation of the result.
"""
parts: List[str] = []
for content in self.result:
if isinstance(content, TextResultContent):
parts.append(content.content)
elif isinstance(content, ImageResultContent):
if replace_image is not None:
parts.append(replace_image)
else:
parts.append(f"[Image: {content.content.to_base64()}]")
return "\n".join(parts)
class Workbench(ABC, ComponentBase[BaseModel]):
"""
A workbench is a component that provides a set of tools that may share
resources and state.
A workbench is responsible for managing the lifecycle of the tools and
providing a single interface to call them. The tools provided by the workbench
may be dynamic and their availabilities may change after each tool execution.
A workbench can be started by calling the :meth:`~agentdhal_core.tools.Workbench.start` method
and stopped by calling the :meth:`~agentdhal_core.tools.Workbench.stop` method.
It can also be used as an asynchronous context manager, which will automatically
start and stop the workbench when entering and exiting the context.
"""
component_type = "workbench"
@abstractmethod
async def list_tools(self) -> List[ToolSchema]:
"""
List the currently available tools in the workbench as :class:`ToolSchema`
objects.
The list of tools may be dynamic, and their content may change after
tool execution.
"""
...
@abstractmethod
async def call_tool(
self,
name: str,
arguments: Mapping[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> ToolResult:
"""
Call a tool in the workbench.
Args:
name (str): The name of the tool to call.
arguments (Mapping[str, Any] | None): The arguments to pass to the tool.
If None, the tool will be called with no arguments.
cancellation_token (CancellationToken | None): An optional cancellation token
to cancel the tool execution.
call_id (str | None): An optional identifier for the tool call, used for tracing.
Returns:
ToolResult: The result of the tool execution.
"""
...
@abstractmethod
async def start(self) -> None:
"""
Start the workbench and initialize any resources.
This method should be called before using the workbench.
"""
...
@abstractmethod
async def stop(self) -> None:
"""
Stop the workbench and release any resources.
This method should be called when the workbench is no longer needed.
"""
...
@abstractmethod
async def reset(self) -> None:
"""
Reset the workbench to its initialized, started state.
"""
...
@abstractmethod
async def save_state(self) -> Mapping[str, Any]:
"""
Save the state of the workbench.
This method should be called to persist the state of the workbench.
"""
...
@abstractmethod
async def load_state(self, state: Mapping[str, Any]) -> None:
"""
Load the state of the workbench.
Args:
state (Mapping[str, Any]): The state to load into the workbench.
"""
...
async def __aenter__(self) -> Self:
"""
Enter the workbench context manager.
This method is called when the workbench is used in a `with` statement.
It calls the :meth:`~agentdhal_core.tools.WorkBench.start` method to start the workbench.
"""
await self.start()
return self
async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
"""
Exit the workbench context manager.
This method is called when the workbench is used in a `with` statement.
It calls the :meth:`~agentdhal_core.tools.WorkBench.stop` method to stop the workbench.
"""
await self.stop()
class StreamWorkbench(Workbench, ABC):
"""A workbench that supports streaming results from tool calls."""
@abstractmethod
def call_tool_stream(
self,
name: str,
arguments: Mapping[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> AsyncGenerator[Any | ToolResult, None]:
"""
Call a tool in the workbench and return a stream of results.
Args:
name (str): The name of the tool to call.
arguments (Mapping[str, Any] | None): The arguments to pass to the tool
If None, the tool will be called with no arguments.
cancellation_token (CancellationToken | None): An optional cancellation token
to cancel the tool execution.
call_id (str | None): An optional identifier for the tool call, used for tracing.
"""
...