first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
import importlib.metadata
__version__ = importlib.metadata.version("agentdhal_extensions")

View File

@@ -0,0 +1,10 @@
try:
from ._azure_ai_agent import AzureAIAgent
except ImportError as e:
raise ImportError(
"Dependencies for AzureAIAgent not found. "
'Please install autogen-ext with the "azure" extra: '
'pip install "agentdhal-ext[azure]"'
) from e
__all__ = ["AzureAIAgent"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union
from agentdhal_core.tools import Tool
from pydantic import BaseModel, Field
from azure.ai.agents.models import (
AzureAISearchToolDefinition,
AzureFunctionToolDefinition,
BingGroundingToolDefinition,
CodeInterpreterToolDefinition,
FileSearchToolDefinition,
MessageTextUrlCitationAnnotation,
)
ListToolType = Iterable[
Union[
Literal[
"file_search",
"code_interpreter",
"bing_grounding",
"azure_ai_search",
"azure_function",
],
BingGroundingToolDefinition,
CodeInterpreterToolDefinition,
AzureAISearchToolDefinition,
FileSearchToolDefinition,
AzureFunctionToolDefinition,
Tool,
Callable[..., Any],
Callable[..., Awaitable[Any]],
]
]
class AzureAIAgentState(BaseModel):
"""
Represents the state of an AzureAIAgent that can be saved and loaded.
This state model keeps track of persistent information about an agent session
including agent and thread identifiers, message history, and associated resources.
Attributes:
type (str): The type identifier for the state object, always "AzureAIAgentState"
agent_id (Optional[str]): The ID of the Azure AI agent
thread_id (Optional[str]): The ID of the conversation thread
initial_message_ids (List[str]): List of message IDs from the initial state
vector_store_id (Optional[str]): The ID of the associated vector store for file search
uploaded_file_ids (List[str]): List of IDs for files uploaded to the agent
"""
type: str = Field(default="AzureAIAgentState")
agent_id: Optional[str] = None
thread_id: Optional[str] = None
initial_message_ids: List[str] = Field(default_factory=list)
vector_store_id: Optional[str] = None
uploaded_file_ids: List[str] = Field(default_factory=list)
def has_annotations(obj: Any) -> TypeGuard[list[MessageTextUrlCitationAnnotation]]:
return obj is not None and isinstance(obj, list)

View File

@@ -0,0 +1,3 @@
from ._file_surfer import FileSurfer
__all__ = ["FileSurfer"]

View File

@@ -0,0 +1,208 @@
import json
import os
import traceback
from typing import List, Sequence, Tuple
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
BaseChatMessage,
TextMessage,
)
from agentdhal_agentchat.utils import remove_images
from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self
from ._markdown_file_browser import MarkdownFileBrowser
# from typing_extensions import Annotated
from ._tool_definitions import (
TOOL_FIND_NEXT,
TOOL_FIND_ON_PAGE_CTRL_F,
TOOL_OPEN_PATH,
TOOL_PAGE_DOWN,
TOOL_PAGE_UP,
)
class FileSurferConfig(BaseModel):
"""Configuration for FileSurfer agent"""
name: str
model_client: ComponentModel
description: str | None = None
class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
"""An agent, used by MagenticOne, that acts as a local file previewer. FileSurfer can open and read a variety of common file types, and can navigate the local file hierarchy.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[file-surfer]"
Args:
name (str): The agent's name
model_client (ChatCompletionClient): The model to use (must be tool-use enabled)
description (str): The agent's description used by the team. Defaults to DEFAULT_DESCRIPTION
base_path (str): The base path to use for the file browser. Defaults to the current working directory.
"""
component_config_schema = FileSurferConfig
component_provider_override = "agentdhal_extensions.agents.file_surfer.FileSurfer"
DEFAULT_DESCRIPTION = "An agent that can handle local files."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage(
content="""
You are a helpful AI Assistant.
When given a user query, use available functions to help the user with their request."""
),
]
def __init__(
self,
name: str,
model_client: ChatCompletionClient,
description: str = DEFAULT_DESCRIPTION,
base_path: str = os.getcwd(),
) -> None:
super().__init__(name, description)
self._model_client = model_client
self._chat_history: List[LLMMessage] = []
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
for chat_message in messages:
self._chat_history.append(chat_message.to_model_message())
try:
_, content = await self._generate_reply(cancellation_token=cancellation_token)
self._chat_history.append(AssistantMessage(content=content, source=self.name))
return Response(chat_message=TextMessage(content=content, source=self.name))
except BaseException:
content = f"File surfing error:\n\n{traceback.format_exc()}"
self._chat_history.append(AssistantMessage(content=content, source=self.name))
return Response(chat_message=TextMessage(content=content, source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
self._chat_history.clear()
def _get_browser_state(self) -> Tuple[str, str]:
"""
Get the current state of the browser, including the header and content.
"""
header = f"Path: {self._browser.path}\n"
if self._browser.page_title is not None:
header += f"Title: {self._browser.page_title}\n"
current_page = self._browser.viewport_current_page
total_pages = len(self._browser.viewport_pages)
header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"
return (header, self._browser.viewport)
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:
history = self._chat_history[0:-1]
last_message = self._chat_history[-1]
assert isinstance(last_message, UserMessage)
task_content = last_message.content # the last message from the sender is the task
assert self._browser is not None
context_message = UserMessage(
source="user",
content=f"Your file viewer is currently open to the file or directory '{self._browser.page_title}' with path '{self._browser.path}'.",
)
task_message = UserMessage(
source="user",
content=task_content,
)
create_result = await self._model_client.create(
messages=self._get_compatible_context(history + [context_message, task_message]),
tools=[
TOOL_OPEN_PATH,
TOOL_PAGE_DOWN,
TOOL_PAGE_UP,
TOOL_FIND_NEXT,
TOOL_FIND_ON_PAGE_CTRL_F,
],
cancellation_token=cancellation_token,
)
response = create_result.content
if isinstance(response, str):
# Answer directly.
return False, response
elif isinstance(response, list) and all(isinstance(item, FunctionCall) for item in response):
function_calls = response
for function_call in function_calls:
tool_name = function_call.name
try:
arguments = json.loads(function_call.arguments)
except json.JSONDecodeError as e:
error_str = f"File surfer encountered an error decoding JSON arguments: {e}"
return False, error_str
if tool_name == "open_path":
path = arguments["path"]
self._browser.open_path(path)
elif tool_name == "page_up":
self._browser.page_up()
elif tool_name == "page_down":
self._browser.page_down()
elif tool_name == "find_on_page_ctrl_f":
search_string = arguments["search_string"]
self._browser.find_on_page(search_string)
elif tool_name == "find_next":
self._browser.find_next()
header, content = self._get_browser_state()
final_response = header.strip() + "\n=======================\n" + content
return False, final_response
final_response = "TERMINATE"
return False, final_response
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
if self._model_client.model_info["vision"]:
return messages
else:
return remove_images(messages)
def _to_config(self) -> FileSurferConfig:
return FileSurferConfig(
name=self.name,
model_client=self._model_client.dump_component(),
description=self.description,
)
@classmethod
def _from_config(cls, config: FileSurferConfig) -> Self:
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
description=config.description or cls.DEFAULT_DESCRIPTION,
)

View File

@@ -0,0 +1,317 @@
# ruff: noqa: E722
import datetime
import io
import os
import re
import time
from typing import List, Optional, Tuple, Union
# TODO: Fix unfollowed import
from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException # type: ignore
class MarkdownFileBrowser:
"""
(In preview) An extremely simple Markdown-powered file browser.
"""
# TODO: Fix unfollowed import
def __init__( # type: ignore
self,
viewport_size: Union[int, None] = 1024 * 8,
base_path: str | None = os.getcwd(),
cwd: str | None = None,
):
"""
Instantiate a new MarkdownFileBrowser.
Arguments:
viewport_size: Approximately how many *characters* fit in the viewport. Viewport dimensions are adjusted dynamically to avoid cutting off words (default: 8192).
base_path: The base path to use for the file browser. Files outside this path cannot be accessed. Defaults to the current working directory.
cwd: The browser's current working directory. Defaults to the system's current working directory.
"""
self.viewport_size = viewport_size # Applies only to the standard uri types
self.history: List[Tuple[str, float]] = list()
self.page_title: Optional[str] = None
self.viewport_current_page = 0
self.viewport_pages: List[Tuple[int, int]] = list()
self._markdown_converter = MarkItDown()
self._base_path = None if base_path is None else os.path.realpath(base_path)
self._page_content: str = ""
self._find_on_page_query: Union[str, None] = None
self._find_on_page_last_result: Union[int, None] = None # Location of the last result
# Set the working directory
if cwd is None:
if self._validate_path(os.getcwd()):
# Use the current working directory if it's in the base path
cwd = os.path.realpath(os.getcwd())
elif self._base_path is not None:
# Otherwise, use the base path
cwd = os.path.realpath(self._base_path)
else:
raise ValueError("No valid working directory (cwd) provided.")
elif not self._validate_path(cwd):
# A cwd was provided, but it is not valid
raise ValueError(f"Working directory (cwd) '{cwd}' is not valid. It must be within the base path.")
# Populate the history with the current working directory
self.set_path(os.path.realpath(cwd))
@property
def path(self) -> str:
"""Return the path of the current page."""
assert len(self.history) > 0
return self.history[-1][0]
def _validate_path(self, path: str) -> bool:
"""Validates the path to ensure it is within the base path.
Arguments:
path: The path to validate.
Returns:
True if the path is valid, False otherwise.
"""
if self._base_path is None:
return True
# Normalize the paths
path = os.path.realpath(path)
base = os.path.realpath(self._base_path)
# Check if the path is within the base path
if os.path.commonpath([path, base]) != base:
return False
return True
def set_path(self, path: str) -> None:
"""Sets the path of the current page.
This will result in the file being opened for reading.
Arguments:
path: An absolute or relative path of the file or directory to open."
"""
# Handle relative paths
path = os.path.expanduser(path)
if not os.path.isabs(path):
if os.path.isfile(self.path):
path = os.path.abspath(os.path.join(os.path.dirname(self.path), path))
elif os.path.isdir(self.path):
path = os.path.abspath(os.path.join(self.path, path))
# If neither a file or a directory, take it verbatim
# Validating the path wrt. the base path is done in _open_path
path = os.path.realpath(path)
self.history.append((path, time.time()))
self._open_path(path)
self.viewport_current_page = 0
self.find_on_page_query = None
self.find_on_page_viewport = None
@property
def viewport(self) -> str:
"""Return the content of the current viewport."""
bounds = self.viewport_pages[self.viewport_current_page]
return self.page_content[bounds[0] : bounds[1]]
@property
def page_content(self) -> str:
"""Return the full contents of the current page."""
return self._page_content
def _set_page_content(self, content: str, split_pages: bool = True) -> None:
"""Sets the text content of the current page."""
self._page_content = content
if split_pages:
self._split_pages()
else:
self.viewport_pages = [(0, len(self._page_content))]
if self.viewport_current_page >= len(self.viewport_pages):
self.viewport_current_page = len(self.viewport_pages) - 1
def page_down(self) -> None:
"""Move the viewport down one page, if possible."""
self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
def page_up(self) -> None:
"""Move the viewport up one page, if possible."""
self.viewport_current_page = max(self.viewport_current_page - 1, 0)
def find_on_page(self, query: str) -> Union[str, None]:
"""Searches for the query from the current viewport forward, looping back to the start if necessary."""
# Did we get here via a previous find_on_page search with the same query?
# If so, map to find_next
if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
return self.find_next()
# Ok it's a new search start from the current viewport
self._find_on_page_query = query
viewport_match = self._find_next_viewport(query, self.viewport_current_page)
if viewport_match is None:
self._find_on_page_last_result = None
return None
else:
self.viewport_current_page = viewport_match
self._find_on_page_last_result = viewport_match
return self.viewport
def find_next(self) -> Union[str, None]:
"""Scroll to the next viewport that matches the query"""
if self._find_on_page_query is None:
return None
starting_viewport = self._find_on_page_last_result
if starting_viewport is None:
starting_viewport = 0
else:
starting_viewport += 1
if starting_viewport >= len(self.viewport_pages):
starting_viewport = 0
viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
if viewport_match is None:
self._find_on_page_last_result = None
return None
else:
self.viewport_current_page = viewport_match
self._find_on_page_last_result = viewport_match
return self.viewport
def _find_next_viewport(self, query: Optional[str], starting_viewport: int) -> Union[int, None]:
"""Search for matches between the starting viewport looping when reaching the end."""
if query is None:
return None
# Normalize the query, and convert to a regular expression
nquery = re.sub(r"\*", "__STAR__", query)
nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word
nquery = nquery.replace("__STAR__", ".*").lower()
if nquery.strip() == "":
return None
idxs: List[int] = list()
idxs.extend(range(starting_viewport, len(self.viewport_pages)))
idxs.extend(range(0, starting_viewport))
for i in idxs:
bounds = self.viewport_pages[i]
content = self.page_content[bounds[0] : bounds[1]]
# TODO: Remove markdown links and images
ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
if re.search(nquery, ncontent):
return i
return None
def open_path(self, path: str) -> str:
"""Open a file or directory in the file surfer."""
self.set_path(path)
return self.viewport
def _split_pages(self) -> None:
"""Split the page contents into pages that are approximately the viewport size. Small deviations are permitted to ensure words are not broken."""
# Handle empty pages
if len(self._page_content) == 0:
self.viewport_pages = [(0, 0)]
return
# Break the viewport into pages
self.viewport_pages = []
start_idx = 0
while start_idx < len(self._page_content):
end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
# Adjust to end on a space
while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
end_idx += 1
self.viewport_pages.append((start_idx, end_idx))
start_idx = end_idx
def _open_path(
self,
path: str,
) -> None:
"""Open a file for reading, converting it to Markdown in the process.
Arguments:
path: The path of the file or directory to open.
"""
if not self._validate_path(path):
# Not robust to TOCTOU issues.
# Mitigate by running with limited permissions, or use a sandbox.
self.page_title = "FileNotFoundError"
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
else:
try:
if os.path.isdir(path): # TODO: Fix markdown_converter types
res = self._markdown_converter.convert_stream( # type: ignore
io.BytesIO(self._fetch_local_dir(path).encode("utf-8")), file_extension=".txt"
)
assert self._validate_path(path)
self.page_title = res.title
self._set_page_content(res.text_content, split_pages=False)
else:
res = self._markdown_converter.convert_local(path)
assert self._validate_path(path)
self.page_title = res.title
self._set_page_content(res.text_content)
except UnsupportedFormatException:
self.page_title = "UnsupportedFormatException"
self._set_page_content(f"# UnsupportedFormatException\n\nCannot preview '{path}' as Markdown.")
except FileConversionException:
self.page_title = "FileConversionException."
self._set_page_content(f"# FileConversionException\n\nError converting '{path}' to Markdown.")
except FileNotFoundError:
self.page_title = "FileNotFoundError"
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
def _fetch_local_dir(self, local_path: str) -> str:
"""Render a local directory listing in HTML to assist with local file browsing via the "file://" protocol.
Through rendered in HTML, later parts of the pipeline will convert the listing to Markdown.
Arguments:
local_path: A path to the local directory whose contents are to be listed.
Returns:
A directory listing, rendered in HTML.
"""
listing = f"""
# Index of {local_path}
| Name | Size | Date Modified |
| ---- | ---- | ------------- |
| .. (parent directory) | | |
"""
for entry in os.listdir(local_path):
size = ""
full_path = os.path.join(local_path, entry)
mtime = ""
try:
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(full_path)).strftime("%Y-%m-%d %H:%M")
except Exception as e:
# Handles PermissionError, etc.
mtime = f"N/A: {type(e).__name__}"
if os.path.isdir(full_path):
entry = entry + os.path.sep
else:
try:
size = str(os.path.getsize(full_path))
except Exception as e:
# Handles PermissionError, etc.
size = f"N/A: {type(e).__name__}"
listing += f"| {entry} | {size} | {mtime} |\n"
return listing

View File

@@ -0,0 +1,50 @@
from agentdhal_core.tools import ParametersSchema, ToolSchema
TOOL_OPEN_PATH = ToolSchema(
name="open_path",
description="Open a local file or directory at a path in the text-based file browser and return current viewport content.",
parameters=ParametersSchema(
type="object",
properties={
"path": {
"type": "string",
"description": "The relative or absolute path of a local file to visit.",
},
},
required=["path"],
),
)
TOOL_PAGE_UP = ToolSchema(
name="page_up",
description="Scroll the viewport UP one page-length in the current file and return the new viewport content.",
)
TOOL_PAGE_DOWN = ToolSchema(
name="page_down",
description="Scroll the viewport DOWN one page-length in the current file and return the new viewport content.",
)
TOOL_FIND_ON_PAGE_CTRL_F = ToolSchema(
name="find_on_page_ctrl_f",
description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F.",
parameters=ParametersSchema(
type="object",
properties={
"search_string": {
"type": "string",
"description": "The string to search for on the page. This search string supports wildcards like '*'",
},
},
required=["search_string"],
),
)
TOOL_FIND_NEXT = ToolSchema(
name="find_next",
description="Scroll the viewport to next occurrence of the search string.",
)

View File

@@ -0,0 +1,10 @@
try:
from ._magentic_one_coder_agent import MagenticOneCoderAgent
except ImportError as e:
raise ImportError(
"Dependencies for MagenticOneCoderAgent not found. "
'Please install autogen-ext with the "magentic-one" extra: '
'pip install "agentdhal-ext[magentic-one]"'
) from e
__all__ = ["MagenticOneCoderAgent"]

View File

@@ -0,0 +1,41 @@
from typing import Any
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.models import (
ChatCompletionClient,
)
MAGENTIC_ONE_CODER_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
MAGENTIC_ONE_CODER_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use the 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible."""
class MagenticOneCoderAgent(AssistantAgent):
"""An agent, used by MagenticOne that provides coding assistance using an LLM model client.
The prompts and description are sealed, to replicate the original MagenticOne configuration. See AssistantAgent if you wish to modify these values.
"""
component_provider_override = "agentdhal_extensions.agents.magentic_one.MagenticOneCoderAgent"
def __init__(
self,
name: str,
model_client: ChatCompletionClient,
**kwargs: Any,
):
super().__init__(
name,
model_client,
description=MAGENTIC_ONE_CODER_DESCRIPTION,
system_message=MAGENTIC_ONE_CODER_SYSTEM_MESSAGE,
)

View File

@@ -0,0 +1,7 @@
from ._openai_agent import OpenAIAgent
from ._openai_assistant_agent import OpenAIAssistantAgent
__all__ = [
"OpenAIAgent",
"OpenAIAssistantAgent",
]

View File

@@ -0,0 +1,682 @@
import asyncio
import logging
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallSummaryMessage,
)
from agentdhal_core import CancellationToken, Component
from agentdhal_core.models import UserMessage
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore
# Number of characters to display when previewing image content in logs and UI
# Base64 encoded images can be very long, so we truncate for readability
IMAGE_CONTENT_PREVIEW_LENGTH = 50
# NOTE: We use the new Responses API, so ChatCompletion imports are not needed.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
# TypedDict classes for built-in tool configurations
class FileSearchToolConfig(TypedDict):
"""Configuration for file_search tool."""
type: Literal["file_search"]
vector_store_ids: List[str] # required - The IDs of the vector stores to search
max_num_results: NotRequired[int] # optional
ranking_options: NotRequired[Dict[str, Any]] # optional
filters: NotRequired[Dict[str, Any]] # optional
class WebSearchToolConfig(TypedDict):
"""Configuration for web_search_preview tool."""
type: Literal["web_search_preview"]
search_context_size: NotRequired[str] # optional
user_location: NotRequired[Union[str, Dict[str, Any]]] # optional - Can be string or structured location
class ComputerUseToolConfig(TypedDict):
"""Configuration for computer_use_preview tool."""
type: Literal["computer_use_preview"]
display_height: int # required - Display height in pixels
display_width: int # required - Display width in pixels
environment: str # required - Environment type for computer use
class MCPToolConfig(TypedDict):
"""Configuration for mcp tool."""
type: Literal["mcp"]
server_label: str # required - Label for the MCP server
server_url: str # required - URL of the MCP server
allowed_tools: NotRequired[List[str]] # optional - List of allowed tools
headers: NotRequired[Dict[str, str]] # optional - HTTP headers for requests
require_approval: NotRequired[bool] # optional - Whether to require user approval
class CodeInterpreterToolConfig(TypedDict):
"""Configuration for code_interpreter tool."""
type: Literal["code_interpreter"]
container: str | Dict[str, Any] # required - Container configuration for code execution
class ImageGenerationToolConfig(TypedDict):
"""Configuration for image_generation tool."""
type: Literal["image_generation"]
background: NotRequired[str] # optional - Background color or image
input_image_mask: NotRequired[str] # optional - Mask for input image editing
class LocalShellToolConfig(TypedDict):
"""Configuration for local_shell tool.
WARNING: This tool is only supported with the 'codex-mini-latest' model
and is available exclusively through the Responses API.
"""
type: Literal["local_shell"]
# Note: local_shell currently has no additional parameters in the API
# Union type for all built-in tool configurations
BuiltinToolConfig = Union[
FileSearchToolConfig,
WebSearchToolConfig,
ComputerUseToolConfig,
MCPToolConfig,
CodeInterpreterToolConfig,
ImageGenerationToolConfig,
LocalShellToolConfig,
]
# Define ImageMessage class early since it's used in _convert_message_to_openai_message
class ImageMessage(BaseChatMessage):
"""A message containing an image."""
content: str # URL or base64 string
def to_model_message(self) -> UserMessage:
return UserMessage(content=self.content, source=self.source)
def to_model_text(self) -> str:
return "[image]"
def to_text(self) -> str:
# Truncate long image content (especially base64) for better readability
# While still showing enough of the URL or content to be identifiable
if len(self.content) > IMAGE_CONTENT_PREVIEW_LENGTH:
return f"[Image: {self.content[:IMAGE_CONTENT_PREVIEW_LENGTH]}...]"
return f"[Image: {self.content}]"
class OpenAIMessageContent(TypedDict):
type: str
text: str
class OpenAIImageUrlContent(TypedDict):
url: str
class OpenAIImageContent(TypedDict):
type: str
image_url: OpenAIImageUrlContent
class OpenAIMessage(TypedDict):
role: str
content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]]
def _convert_message_to_openai_message(
message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage],
) -> OpenAIMessage:
"""Convert an AutoGen message to an OpenAI message format."""
if isinstance(message, TextMessage):
if message.source == "user":
return {"role": "user", "content": str(message.content)}
elif message.source == "system":
return {"role": "system", "content": str(message.content)}
elif message.source == "assistant":
return {"role": "assistant", "content": str(message.content)}
else:
return {"role": "user", "content": str(message.content)}
elif isinstance(message, MultiModalMessage):
content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = []
for part in message.content:
if isinstance(part, TextMessage):
content_parts.append({"type": "text", "text": str(part.content)})
elif isinstance(part, ImageMessage):
image_content = str(part.content)
content_parts.append({"type": "image_url", "image_url": {"url": image_content}})
return {"role": "user", "content": content_parts}
else:
return {"role": "user", "content": str(message.content)}
class OpenAIAgentState(BaseModel):
type: str = Field(default="OpenAIAgentState")
response_id: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)
class OpenAIAgentConfig(BaseModel):
"""
Configuration model for OpenAI agent supporting OpenAI built-in tools only.
.. versionchanged:: v0.7.0
Added support for built-in tools in JSON configuration via _to_config and _from_config methods.
The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format).
Custom tools are not supported.
"""
name: str
description: str
model: str
instructions: str
tools: List[Dict[str, Any] | str] | None = None
temperature: Optional[float] = 1
max_output_tokens: Optional[int] = None
json_mode: bool = False
store: bool = True
truncation: str = "disabled"
class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]):
"""
An agent implementation that uses the OpenAI Responses API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]"
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Responses API to generate responses with capabilities like:
* Multi-turn conversations
* Built-in tool support (file_search, code_interpreter, web_search_preview, etc.)
Currently, custom tools are not supported.
.. versionchanged:: v0.7.0
Added support for built-in tool types like file_search, web_search_preview,
code_interpreter, computer_use_preview, image_generation, and mcp.
Added support for tool configurations with required and optional parameters.
Built-in tools are split into two categories:
**Tools that can use string format** (no required parameters):
- web_search_preview: Can be used as "web_search_preview" or with optional config
(user_location, search_context_size)
- image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask)
- local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model)
**Tools that REQUIRE dict configuration** (have required parameters):
- file_search: MUST use dict with vector_store_ids (List[str])
- computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str)
- code_interpreter: MUST use dict with container (str)
- mcp: MUST use dict with server_label (str), server_url (str)
Using required-parameter tools in string format will raise a ValueError with helpful error messages.
The tools parameter type annotation only accepts string values for tools that don't require parameters.
Note:
Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent.
Only OpenAI built-in tools provided via the Responses API are supported.
Args:
name (str): Name of the agent
description (str): Description of the agent's purpose
client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance
model (str): Model to use (e.g. "gpt-4.1")
instructions (str): System instructions for the agent
tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use.
Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell".
Dict values can provide configuration for built-in tools with parameters.
Required parameters for built-in tools:
- file_search: vector_store_ids (List[str])
- computer_use_preview: display_height (int), display_width (int), environment (str)
- code_interpreter: container (str)
- mcp: server_label (str), server_url (str)
Optional parameters for built-in tools:
- file_search: max_num_results (int), ranking_options (dict), filters (dict)
- web_search_preview: user_location (str or dict), search_context_size (int)
- image_generation: background (str), input_image_mask (str)
- mcp: allowed_tools (List[str]), headers (dict), require_approval (bool)
Special tools with model restrictions:
- local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support)
Custom tools are not supported.
temperature (Optional[float]): Temperature for response generation (default: 1)
max_output_tokens (Optional[int]): Maximum output tokens
json_mode (bool): Whether to use JSON mode (default: False)
store (bool): Whether to store conversations (default: True)
truncation (str): Truncation strategy (default: "disabled")
Example:
Basic usage with built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example():
client = AsyncOpenAI()
agent = OpenAIAgent(
name="SimpleAgent",
description="A simple OpenAI agent using the Responses API",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant.",
tools=["web_search_preview"], # Only tools without required params
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example())
Usage with configured built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example_with_configs():
client = AsyncOpenAI()
# Configure tools with required and optional parameters
tools = [
# {
# "type": "file_search",
# "vector_store_ids": ["vs_abc123"], # required
# "max_num_results": 10, # optional
# },
# {
# "type": "computer_use_preview",
# "display_height": 1024, # required
# "display_width": 1280, # required
# "environment": "linux", # required
# },
{
"type": "code_interpreter",
"container": {"type": "auto"}, # required
},
# {
# "type": "mcp",
# "server_label": "my-mcp-server", # required
# "server_url": "http://localhost:3000", # required
# },
{
"type": "web_search_preview",
"user_location": { # optional - structured location
"type": "approximate", # required: "approximate" or "exact"
"country": "US", # optional
"region": "CA", # optional
"city": "San Francisco", # optional
},
"search_context_size": "low", # optional
},
# "image_generation", # Simple tools can still use string format
]
agent = OpenAIAgent(
name="ConfiguredAgent",
description="An agent with configured tools",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant with specialized tools.",
tools=tools, # type: ignore
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example_with_configs())
Note:
Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API.
"""
component_config_schema = OpenAIAgentConfig
component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent"
def __init__(
self: "OpenAIAgent",
name: str,
description: str,
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
model: str,
instructions: str,
tools: Optional[
Iterable[
Union[
Literal["web_search_preview", "image_generation", "local_shell"],
BuiltinToolConfig,
]
]
] = None,
temperature: Optional[float] = 1,
max_output_tokens: Optional[int] = None,
json_mode: bool = False,
store: bool = True,
truncation: str = "disabled",
) -> None:
super().__init__(name, description)
self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client
self._model: str = model
self._instructions: str = instructions
self._temperature: Optional[float] = temperature
self._max_output_tokens: Optional[int] = max_output_tokens
self._json_mode: bool = json_mode
self._store: bool = store
self._truncation: str = truncation
self._last_response_id: Optional[str] = None
self._message_history: List[Dict[str, Any]] = []
self._tools: List[Dict[str, Any]] = []
if tools is not None:
for tool in tools:
if isinstance(tool, str):
# Handle built-in tool types
self._add_builtin_tool(tool)
elif isinstance(tool, dict) and "type" in tool:
# Handle configured built-in tools
self._tools.append(cast(dict[str, Any], tool))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
def _add_builtin_tool(self, tool_name: str) -> None:
"""Add a built-in tool by name."""
# Skip if an identical tool has already been registered (idempotent behaviour)
if any(td.get("type") == tool_name for td in self._tools):
return # Duplicate ignore rather than raise to stay backward-compatible
# Only allow string format for tools that don't require parameters
if tool_name == "web_search_preview":
self._tools.append({"type": "web_search_preview"})
elif tool_name == "image_generation":
self._tools.append({"type": "image_generation"})
elif tool_name == "local_shell":
# Special handling for local_shell - very limited model support
if self._model != "codex-mini-latest":
raise ValueError(
f"Tool 'local_shell' is only supported with model 'codex-mini-latest', "
f"but current model is '{self._model}'. "
f"This tool is available exclusively through the Responses API and has severe limitations. "
f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with "
f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead."
)
self._tools.append({"type": "local_shell"})
elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]:
# These tools require specific parameters and must use dict configuration
raise ValueError(
f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. "
f"Use dict configuration instead. Required parameters for {tool_name}: "
f"{self._get_required_params_help(tool_name)}"
)
else:
raise ValueError(f"Unsupported built-in tool type: {tool_name}")
def _get_required_params_help(self, tool_name: str) -> str:
"""Get help text for required parameters of a tool."""
help_text = {
"file_search": "vector_store_ids (List[str])",
"code_interpreter": "container (str | dict)",
"computer_use_preview": "display_height (int), display_width (int), environment (str)",
"mcp": "server_label (str), server_url (str)",
}
return help_text.get(tool_name, "unknown parameters")
def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]:
"""Convert an OpenAIMessage to a Dict[str, Any]."""
return dict(message)
@property
def produced_message_types(
self: "OpenAIAgent",
) -> Sequence[
Union[
Type[TextMessage],
Type[MultiModalMessage],
Type[StopMessage],
Type[ToolCallSummaryMessage],
Type[HandoffMessage],
]
]:
"""Return the types of messages that this agent can produce."""
return [TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
# Custom tool execution is not supported by this agent.
def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]:
has_system_message = any(msg.get("role") == "system" for msg in messages)
if self._instructions and not has_system_message:
messages = [{"role": "system", "content": self._instructions}] + messages
api_params: Dict[str, Any] = {
"model": self._model,
"input": messages, # Responses API expects 'input'
}
if self._temperature is not None:
api_params["temperature"] = self._temperature
if self._max_output_tokens is not None:
api_params["max_output_tokens"] = self._max_output_tokens
if self._tools:
api_params["tools"] = self._tools
if self._json_mode:
api_params["text"] = {"type": "json_object"}
api_params["store"] = self._store
api_params["truncation"] = self._truncation
if self._last_response_id:
api_params["previous_response_id"] = self._last_response_id
return api_params
async def on_messages(
self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> Response:
response = None
inner_messages: List[
Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
] = []
async for msg in self.on_messages_stream(messages, cancellation_token):
if isinstance(msg, Response):
response = msg
# ModelClientStreamingChunkEvent does not exist in this version, so skip this check
else:
inner_messages.append(msg)
if response is None:
raise ValueError("No response was generated")
if response.inner_messages is None:
response.inner_messages = []
for msg in inner_messages:
if msg not in response.inner_messages:
response.inner_messages = list(response.inner_messages) + [msg]
return response
async def on_messages_stream(
self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[
Union[
AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response
],
None,
]:
input_messages: List[Dict[str, Any]] = []
if self._message_history:
input_messages.extend(self._message_history)
for message in messages:
if isinstance(
message, (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)
):
openai_message = _convert_message_to_openai_message(message)
dict_message = self._convert_message_to_dict(openai_message)
input_messages.append(dict_message)
self._message_history.append(dict_message)
else:
msg_content = str(cast(Any, message).content) if hasattr(message, "content") else str(message)
dict_message = {"role": "user", "content": msg_content}
input_messages.append(dict_message)
self._message_history.append(dict_message)
inner_messages: List[AgentEvent | ChatMessage] = []
api_params = self._build_api_parameters(input_messages)
try:
client = cast(Any, self._client)
response_obj = await cancellation_token.link_future(
asyncio.ensure_future(client.responses.create(**api_params))
)
content = getattr(response_obj, "output_text", None)
response_id = getattr(response_obj, "id", None)
self._last_response_id = response_id
# Use a readable placeholder when the API returns no content to aid debugging
content_str: str = str(content) if content is not None else "[no content returned]"
self._message_history.append({"role": "assistant", "content": content_str})
final_message = TextMessage(source=self.name, content=content_str)
response = Response(chat_message=final_message, inner_messages=inner_messages)
yield response
except Exception as e:
error_message = f"Error generating response: {str(e)}"
event_logger.error(f"API error: {error_message}", exc_info=True)
error_response = TextMessage(source=self.name, content=error_message)
yield Response(chat_message=error_response, inner_messages=inner_messages)
async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None:
self._last_response_id = None
self._message_history = []
async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]:
state = OpenAIAgentState(
response_id=self._last_response_id,
history=self._message_history,
)
return state.model_dump()
async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None:
agent_state = OpenAIAgentState.model_validate(state)
self._last_response_id = agent_state.response_id
self._message_history = agent_state.history
def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig:
"""Convert the OpenAI agent to a declarative config.
Serializes built-in tools to their appropriate configuration formats for JSON serialization.
Returns:
OpenAIAgentConfig: The configuration that can recreate this agent.
"""
return OpenAIAgentConfig(
name=self.name,
description=self.description,
model=self._model,
instructions=self._instructions,
tools=list(self._tools),
temperature=self._temperature,
max_output_tokens=self._max_output_tokens,
json_mode=self._json_mode,
store=self._store,
truncation=self._truncation,
)
@classmethod
def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent":
"""Create an OpenAI agent from a declarative config.
Handles built-in tools (from string or dict configurations).
Args:
config: The configuration to load the agent from.
Returns:
OpenAIAgent: The reconstructed agent.
"""
from openai import AsyncOpenAI
client = AsyncOpenAI()
return cls(
name=config.name,
description=config.description,
client=client,
model=config.model,
instructions=config.instructions,
tools=config.tools, # type: ignore
temperature=config.temperature,
max_output_tokens=config.max_output_tokens,
json_mode=config.json_mode,
store=config.store,
truncation=config.truncation,
)
# Add public API wrappers for configuration and tools
def to_config(self) -> OpenAIAgentConfig:
"""Public wrapper for the private _to_config method."""
return self._to_config()
@classmethod
def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent":
"""Public wrapper for the private _from_config classmethod."""
return cls._from_config(config)
@property
def tools(self) -> list[Any]:
"""Public access to the agent's tools."""
return self._tools
@property
def model(self) -> str:
"""Public access to the agent's model."""
return self._model

View File

@@ -0,0 +1,715 @@
import asyncio
import json
import logging
import os
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Union,
cast,
)
import aiofiles
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from agentdhal_core import CancellationToken, FunctionCall, Image
from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult
from agentdhal_core.tools import FunctionTool, Tool
from pydantic import BaseModel, Field
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
from openai.pagination import AsyncCursorPage
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
from openai.types import FileObject
from openai.types.beta import thread_update_params
from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
from openai.types.beta.file_search_tool_param import FileSearchToolParam
from openai.types.beta.function_tool_param import FunctionToolParam
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
from openai.types.beta.threads import Message, MessageDeleted, Run
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
from openai.types.beta.threads.image_url_param import ImageURLParam
from openai.types.beta.threads.message_content_part_param import (
MessageContentPartParam,
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.vector_store import VectorStore
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
schema = tool.schema
parameters: Dict[str, object] = {}
if "parameters" in schema:
parameters = {
"type": schema["parameters"]["type"],
"properties": schema["parameters"]["properties"],
}
if "required" in schema["parameters"]:
parameters["required"] = schema["parameters"]["required"]
function_def = FunctionDefinition(
name=schema["name"],
description=schema.get("description", ""),
parameters=parameters,
)
return FunctionToolParam(type="function", function=function_def)
class OpenAIAssistantAgentState(BaseModel):
type: str = Field(default="OpenAIAssistantAgentState")
assistant_id: Optional[str] = None
thread_id: Optional[str] = None
initial_message_ids: List[str] = Field(default_factory=list)
vector_store_id: Optional[str] = None
uploaded_file_ids: List[str] = Field(default_factory=list)
class OpenAIAssistantAgent(BaseChatAgent):
"""An agent implementation that uses the Assistant API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]" # For OpenAI Assistant
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Assistant API to create AI assistants with capabilities like:
* Code interpretation and execution
* File handling and search
* Custom function calling
* Multi-turn conversations
The agent maintains a thread of conversation and can use various tools including
* Code interpreter: For executing code and working with files
* File search: For searching through uploaded documents
* Custom functions: For extending capabilities with user-defined tools
Key Features:
* Supports multiple file formats including code, documents, images
* Can handle up to 128 tools per assistant
* Maintains conversation context in threads
* Supports file uploads for code interpreter and search
* Vector store integration for efficient file search
* Automatic file parsing and embedding
You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters.
Examples:
Use the assistant to analyze data in a CSV file:
.. code-block:: python
from openai import AsyncOpenAI
from agentdhal_core import CancellationToken
import asyncio
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an OpenAI client
client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url")
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Upload files for the assistant to use
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
# Get response from the assistant
response = await assistant.on_messages(
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
)
print(response)
# Clean up resources
await assistant.delete_uploaded_files(cancellation_token)
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Use Azure OpenAI Assistant with AAD authentication:
.. code-block:: python
from openai import AsyncAzureOpenAI
import asyncio
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from agentdhal_core import CancellationToken
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an Azure OpenAI client
token_provider = get_bearer_token_provider(DefaultAzureCredential())
client = AsyncAzureOpenAI(
azure_deployment="YOUR_AZURE_DEPLOYMENT",
api_version="YOUR_API_VERSION",
azure_endpoint="YOUR_AZURE_ENDPOINT",
azure_ad_token_provider=token_provider,
)
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4o",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Get response from the assistant
response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token)
print(response)
# Clean up resources
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Args:
name (str): Name of the assistant
description (str): Description of the assistant's purpose
client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance
model (str): Model to use (e.g. "gpt-4")
instructions (str): System instructions for the assistant
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
assistant_id (Optional[str]): ID of existing assistant to use
thread_id (Optional[str]): ID of existing thread to use
metadata (Optional[Dict[str, str]]): Additional metadata for the assistant.
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
temperature (Optional[float]): Temperature for response generation
tool_resources (Optional[ToolResources]): Additional tool configuration
top_p (Optional[float]): Top p sampling parameter
"""
def __init__(
self,
name: str,
description: str,
client: AsyncOpenAI | AsyncAzureOpenAI,
model: str,
instructions: str,
tools: Optional[
Iterable[
Union[
Literal["code_interpreter", "file_search"],
Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
]
]
] = None,
assistant_id: Optional[str] = None,
thread_id: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
response_format: Optional["AssistantResponseFormatOptionParam"] = None,
temperature: Optional[float] = None,
tool_resources: Optional["ToolResources"] = None,
top_p: Optional[float] = None,
) -> None:
if isinstance(client, ChatCompletionClient):
raise ValueError(
"Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
)
super().__init__(name, description)
if tools is None:
tools = []
# Store original tools and converted tools separately
self._original_tools: List[Tool] = []
converted_tools: List["AssistantToolParam"] = []
for tool in tools:
if isinstance(tool, str):
if tool == "code_interpreter":
converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
elif tool == "file_search":
converted_tools.append(FileSearchToolParam(type="file_search"))
elif isinstance(tool, Tool):
self._original_tools.append(tool)
converted_tools.append(_convert_tool_to_function_param(tool))
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
function_tool = FunctionTool(tool, description=description)
self._original_tools.append(function_tool)
converted_tools.append(_convert_tool_to_function_param(function_tool))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
self._client = client
self._assistant: Optional["Assistant"] = None
self._thread: Optional["Thread"] = None
self._init_thread_id = thread_id
self._model = model
self._instructions = instructions
self._api_tools = converted_tools
self._assistant_id = assistant_id
self._metadata = metadata
self._response_format = response_format
self._temperature = temperature
self._tool_resources = tool_resources
self._top_p = top_p
self._vector_store_id: Optional[str] = None
self._uploaded_file_ids: List[str] = []
# Variables to track initial state
self._initial_message_ids: Set[str] = set()
self._initial_state_retrieved: bool = False
async def _ensure_initialized(self) -> None:
"""Ensure assistant and thread are created."""
if self._assistant is None:
if self._assistant_id:
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) # type: ignore[reportDeprecated]
else:
self._assistant = await self._client.beta.assistants.create( # type: ignore[reportDeprecated]
model=self._model,
description=self.description,
instructions=self._instructions,
tools=self._api_tools,
metadata=self._metadata,
response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore
temperature=self._temperature,
tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore
top_p=self._top_p,
)
if self._thread is None:
if self._init_thread_id:
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) # type: ignore[reportDeprecated]
else:
self._thread = await self._client.beta.threads.create() # type: ignore[reportDeprecated]
# Retrieve initial state only once
if not self._initial_state_retrieved:
await self._retrieve_initial_state()
self._initial_state_retrieved = True
async def _retrieve_initial_state(self) -> None:
"""Retrieve and store the initial state of messages and runs."""
# Retrieve all initial message IDs
initial_message_ids: Set[str] = set()
after: str | NotGiven = NOT_GIVEN
while True:
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list( # type: ignore[reportDeprecated]
self._thread_id, after=after, order="asc", limit=100
)
for msg in msgs.data:
initial_message_ids.add(msg.id)
if not msgs.has_next_page():
break
after = msgs.data[-1].id
self._initial_message_ids = initial_message_ids
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
"""The types of messages that the assistant agent produces."""
return (TextMessage,)
@property
def threads(self) -> AsyncThreads:
return self._client.beta.threads
@property
def runs(self) -> AsyncRuns:
return self._client.beta.threads.runs
@property
def messages(self) -> AsyncMessages:
return self._client.beta.threads.messages
@property
def _get_assistant_id(self) -> str:
if self._assistant is None:
raise ValueError("Assistant not initialized")
return self._assistant.id
@property
def _thread_id(self) -> str:
if self._thread is None:
raise ValueError("Thread not initialized")
return self._thread.id
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
"""Execute a tool call and return the result."""
if not self._original_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token, call_id=tool_call.id)
return tool.return_value_as_string(result)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response."""
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""Handle incoming messages and return a response."""
await self._ensure_initialized()
# Process all messages in sequence
for message in messages:
await self.handle_incoming_message(message, cancellation_token)
# Inner messages for tool calls
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
# Create and start a run
run: Run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.create( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
assistant_id=self._get_assistant_id,
)
)
)
# Wait for run completion by polling
while True:
run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.retrieve( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
run_id=run.id,
)
)
)
if run.status == "failed":
raise ValueError(f"Run failed: {run.last_error}")
# If the run requires action (function calls), execute tools and continue
if run.status == "requires_action" and run.required_action is not None:
tool_calls: List[FunctionCall] = []
for required_tool_call in run.required_action.submit_tool_outputs.tool_calls:
if required_tool_call.type == "function":
tool_calls.append(
FunctionCall(
id=required_tool_call.id,
name=required_tool_call.function.name,
arguments=required_tool_call.function.arguments,
)
)
# Add tool call message to inner messages
tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
inner_messages.append(tool_call_msg)
event_logger.debug(tool_call_msg)
yield tool_call_msg
# Execute tool calls and get results
tool_outputs: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
try:
result = await self._execute_tool_call(tool_call, cancellation_token)
is_error = False
except Exception as e:
result = f"Error: {e}"
is_error = True
tool_outputs.append(
FunctionExecutionResult(
content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
)
)
# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
inner_messages.append(tool_result_msg)
event_logger.debug(tool_result_msg)
yield tool_result_msg
# Submit tool outputs back to the run
run = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.runs.submit_tool_outputs( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
run_id=run.id,
tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
)
)
)
continue
if run.status == "completed":
break
await asyncio.sleep(0.5)
# Get messages after run completion
assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) # type: ignore[reportDeprecated]
)
)
if not assistant_messages.data:
raise ValueError("No messages received from assistant")
# Get the last message's content
last_message = assistant_messages.data[0]
if not last_message.content:
raise ValueError(f"No content in the last message: {last_message}")
# Extract text content
text_content = [content for content in last_message.content if content.type == "text"]
if not text_content:
raise ValueError(f"Expected text content in the last message: {last_message.content}")
# Return the assistant's response as a Response with inner messages
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
yield Response(chat_message=chat_message, inner_messages=inner_messages)
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
"""Handle regular text messages by adding them to the thread."""
content: str | List[MessageContentPartParam] | None = None
llm_message = message.to_model_message()
if isinstance(llm_message.content, str):
content = llm_message.content
else:
content = []
for c in llm_message.content:
if isinstance(c, str):
content.append(TextContentBlockParam(text=c, type="text"))
elif isinstance(c, Image):
content.append(ImageURLContentBlockParam(image_url=ImageURLParam(url=c.data_uri), type="image_url"))
else:
raise ValueError(f"Unsupported content type: {type(c)} in {message}")
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.create( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
content=content,
role="user",
)
)
)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Handle reset command by deleting new messages and runs since initialization."""
await self._ensure_initialized()
# Retrieve all message IDs in the thread
new_message_ids: List[str] = []
after: str | NotGiven = NOT_GIVEN
while True:
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100) # type: ignore[reportDeprecated]
)
)
for msg in msgs.data:
if msg.id not in self._initial_message_ids:
new_message_ids.append(msg.id)
if not msgs.has_next_page():
break
after = msgs.data[-1].id
# Delete new messages
for msg_id in new_message_ids:
status: MessageDeleted = await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) # type: ignore[reportDeprecated]
)
)
assert status.deleted is True
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
"""Upload files and return their IDs."""
await self._ensure_initialized()
if isinstance(file_paths, str):
file_paths = [file_paths]
file_ids: List[str] = []
for file_path in file_paths:
async with aiofiles.open(file_path, mode="rb") as f:
file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read()))
file_name = os.path.basename(file_path)
file: FileObject = await cancellation_token.link_future(
asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants"))
)
file_ids.append(file.id)
self._uploaded_file_ids.append(file.id)
return file_ids
async def on_upload_for_code_interpreter(
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
"""Handle file uploads for the code interpreter."""
await self._ensure_initialized()
file_ids = await self._upload_files(file_paths, cancellation_token)
# Update thread with the new files
thread = await cancellation_token.link_future(
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) # type: ignore[reportDeprecated]
)
tool_resources: ToolResources = thread.tool_resources or ToolResources()
code_interpreter: ToolResourcesCodeInterpreter = (
tool_resources.code_interpreter or ToolResourcesCodeInterpreter()
)
existing_file_ids: List[str] = code_interpreter.file_ids or []
existing_file_ids.extend(file_ids)
tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=existing_file_ids)
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.threads.update( # type: ignore[reportDeprecated]
thread_id=self._thread_id,
tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()),
)
)
)
async def on_upload_for_file_search(
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
"""Handle file uploads for file search."""
await self._ensure_initialized()
# Check if file_search is enabled in tools
if not any(tool.get("type") == "file_search" for tool in self._api_tools):
raise ValueError(
"File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
)
# Create vector store if not already created
if self._vector_store_id is None:
vector_store: VectorStore = await cancellation_token.link_future(
asyncio.ensure_future(self._client.vector_stores.create())
)
self._vector_store_id = vector_store.id
# Update assistant with vector store ID
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.beta.assistants.update(
assistant_id=self._get_assistant_id,
tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
)
)
)
file_ids = await self._upload_files(file_paths, cancellation_token)
# Create file batch with the file IDs
await cancellation_token.link_future(
asyncio.ensure_future(
self._client.vector_stores.file_batches.create_and_poll(
vector_store_id=self._vector_store_id, file_ids=file_ids
)
)
)
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
"""Delete all files that were uploaded by this agent instance."""
await self._ensure_initialized()
for file_id in self._uploaded_file_ids:
try:
await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
except Exception as e:
event_logger.error(f"Failed to delete file {file_id}: {str(e)}")
self._uploaded_file_ids = []
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
"""Delete the assistant if it was created by this instance."""
await self._ensure_initialized()
if self._assistant is not None and not self._assistant_id:
try:
await cancellation_token.link_future(
asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)) # type: ignore[reportDeprecated]
)
self._assistant = None
except Exception as e:
event_logger.error(f"Failed to delete assistant: {str(e)}")
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
"""Delete the vector store if it was created by this instance."""
await self._ensure_initialized()
if self._vector_store_id is not None:
try:
await cancellation_token.link_future(
asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=self._vector_store_id))
)
self._vector_store_id = None
except Exception as e:
event_logger.error(f"Failed to delete vector store: {str(e)}")
async def save_state(self) -> Mapping[str, Any]:
state = OpenAIAssistantAgentState(
assistant_id=self._assistant.id if self._assistant else self._assistant_id,
thread_id=self._thread.id if self._thread else self._init_thread_id,
initial_message_ids=list(self._initial_message_ids),
vector_store_id=self._vector_store_id,
uploaded_file_ids=self._uploaded_file_ids,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
agent_state = OpenAIAssistantAgentState.model_validate(state)
self._assistant_id = agent_state.assistant_id
self._init_thread_id = agent_state.thread_id
self._initial_message_ids = set(agent_state.initial_message_ids)
self._vector_store_id = agent_state.vector_store_id
self._uploaded_file_ids = agent_state.uploaded_file_ids

View File

@@ -0,0 +1,3 @@
from ._video_surfer import VideoSurfer
__all__ = ["VideoSurfer"]

View File

@@ -0,0 +1,172 @@
from typing import Any, Awaitable, Callable, List, Optional
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.models import ChatCompletionClient
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel
from .tools import (
extract_audio,
get_screenshot_at,
get_video_length,
save_screenshot,
transcribe_audio_with_timestamps,
transcribe_video_screenshot,
)
class VideoSurfer(AssistantAgent):
"""
VideoSurfer is a specialized agent designed to answer questions about a local video file.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[video-surfer]"
This agent utilizes various tools to extract information from the video, such as its length, screenshots at specific timestamps, and audio transcriptions. It processes these elements to provide detailed answers to user queries.
Available tools:
- :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio`
- :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length`
- :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps`
- :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at`
- :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot`
- :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot`
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client used for generating responses.
tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
A list of tools or functions the agent can use. If not provided, defaults to all video tools from the action space.
description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.
Example usage:
The following example demonstrates how to create an video surfing agent with
a model client and generate a response to a simple query about a local video
called video.mp4.
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.conditions import TextMentionTermination
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.agents.video_surfer import VideoSurfer
async def main() -> None:
\"\"\"
Main function to run the video agent.
\"\"\"
# Define an agent
video_agent = VideoSurfer(
name="VideoSurfer",
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
)
# Define termination condition
termination = TextMentionTermination("TERMINATE")
# Define a team
agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination)
# Run the team and stream messages to the console
stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of complex does his use? Can you save this example to disk as well?")
await Console(stream)
asyncio.run(main())
The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat.
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.teams import MagenticOneGroupChat
from agentdhal_agentchat.agents import UserProxyAgent
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.agents.video_surfer import VideoSurfer
async def main() -> None:
\"\"\"
Main function to run the video agent.
\"\"\"
model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
# Define an agent
video_agent = VideoSurfer(
name="VideoSurfer",
model_client=model_client
)
web_surfer_agent = UserProxyAgent(
name="User"
)
# Define a team
agent_team = MagenticOneGroupChat([web_surfer_agent, video_agent], model_client=model_client,)
# Run the team and stream messages to the console
stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.")
await Console(stream)
asyncio.run(main())
"""
DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video."
DEFAULT_SYSTEM_MESSAGE = """
You are a helpful agent that is an expert at answering questions from a video.
When asked to answer a question about a video, you should:
1. Check if that video is available locally.
2. Use the transcription to find which part of the video the question is referring to.
3. Optionally use screenshots from those timestamps
4. Provide a detailed answer to the question.
Reply with TERMINATE when the task has been completed.
"""
def __init__(
self,
name: str,
model_client: ChatCompletionClient,
*,
tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
description: Optional[str] = None,
system_message: Optional[str] = None,
):
super().__init__(
name=name,
model_client=model_client,
tools=tools
or [
get_video_length,
get_screenshot_at,
save_screenshot,
self.vs_transribe_video_screenshot,
extract_audio,
transcribe_audio_with_timestamps,
],
description=description or self.DEFAULT_DESCRIPTION,
system_message=system_message or self.DEFAULT_SYSTEM_MESSAGE,
)
async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str:
"""
Transcribes the video screenshot at a specific timestamp.
Args:
video_path (str): Path to the video file.
timestamp (float): Timestamp to take the screenshot.
Returns:
str: Transcription of the video screenshot.
"""
return await transcribe_video_screenshot(video_path, timestamp, self._model_client)

View File

@@ -0,0 +1,156 @@
import base64
from typing import Any, Dict, List, Tuple
import cv2
import ffmpeg
import numpy as np
import whisper
from agentdhal_core import Image as AGImage
from agentdhal_core.models import (
ChatCompletionClient,
UserMessage,
)
def extract_audio(video_path: str, audio_output_path: str) -> str:
"""
Extracts audio from a video file and saves it as an MP3 file.
:param video_path: Path to the video file.
:param audio_output_path: Path to save the extracted audio file.
:return: Confirmation message with the path to the saved audio file.
"""
(ffmpeg.input(video_path).output(audio_output_path, format="mp3").run(quiet=True, overwrite_output=True)) # type: ignore
return f"Audio extracted and saved to {audio_output_path}."
def transcribe_audio_with_timestamps(audio_path: str) -> str:
"""
Transcribes the audio file with timestamps using the Whisper model.
:param audio_path: Path to the audio file.
:return: Transcription with timestamps.
"""
model = whisper.load_model("base") # type: ignore
result: Dict[str, Any] = model.transcribe(audio_path, task="transcribe", language="en", verbose=False) # type: ignore
segments: List[Dict[str, Any]] = result["segments"]
transcription_with_timestamps = ""
for segment in segments:
start: float = segment["start"]
end: float = segment["end"]
text: str = segment["text"]
transcription_with_timestamps += f"[{start:.2f} - {end:.2f}] {text}\n"
return transcription_with_timestamps
def get_video_length(video_path: str) -> str:
"""
Returns the length of the video in seconds.
:param video_path: Path to the video file.
:return: Duration of the video in seconds.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video file {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
duration = frame_count / fps
cap.release()
return f"The video is {duration:.2f} seconds long."
def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None:
"""
Captures a screenshot at the specified timestamp and saves it to the output path.
:param video_path: Path to the video file.
:param timestamp: Timestamp in seconds.
:param output_path: Path to save the screenshot. The file format is determined by the extension in the path.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video file {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frame_number = int(timestamp * fps)
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = cap.read()
if ret:
cv2.imwrite(output_path, frame)
else:
raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
cap.release()
async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str:
"""
Transcribes the content of a video screenshot captured at the specified timestamp using OpenAI API.
:param video_path: Path to the video file.
:param timestamp: Timestamp in seconds.
:param model_client: ChatCompletionClient instance.
:return: Description of the screenshot content.
"""
screenshots = get_screenshot_at(video_path, [timestamp])
if not screenshots:
return "Failed to capture screenshot."
_, frame = screenshots[0]
# Convert the frame to bytes and then to base64 encoding
_, buffer = cv2.imencode(".jpg", frame)
frame_bytes = buffer.tobytes()
frame_base64 = base64.b64encode(frame_bytes).decode("utf-8")
screenshot_uri = f"data:image/jpeg;base64,{frame_base64}"
messages = [
UserMessage(
content=[
"Following is a screenshot from the video at {} seconds. Describe what you see here.",
AGImage.from_uri(screenshot_uri),
],
source="tool",
)
]
result = await model_client.create(messages=messages)
return str(result.content)
def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]:
"""
Captures screenshots at the specified timestamps and returns them as Python objects.
:param video_path: Path to the video file.
:param timestamps: List of timestamps in seconds.
:return: List of tuples containing timestamp and the corresponding frame (image).
Each frame is a NumPy array (height x width x channels).
"""
screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = []
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video file {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
duration = total_frames / fps
for timestamp in timestamps:
if 0 <= timestamp <= duration:
frame_number = int(timestamp * fps)
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = cap.read()
if ret:
# Append the timestamp and frame to the list
screenshots.append((timestamp, frame))
else:
raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
else:
raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]")
cap.release()
return screenshots

View File

@@ -0,0 +1,4 @@
from ._multimodal_web_surfer import MultimodalWebSurfer
from .playwright_controller import PlaywrightController
__all__ = ["MultimodalWebSurfer", "PlaywrightController"]

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class WebSurferEvent:
source: str
message: str
url: str
action: str | None = None
arguments: Dict[str, Any] | None = None

View File

@@ -0,0 +1,988 @@
import asyncio
import base64
import hashlib
import io
import json
import logging
import os
import re
import sys
import time
import traceback
import warnings
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Optional,
Sequence,
)
from urllib.parse import quote_plus
import aiofiles
import PIL.Image
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
from agentdhal_agentchat.utils import content_to_str, remove_images
from agentdhal_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from agentdhal_core import Image as AGImage
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
ModelFamily,
RequestUsage,
SystemMessage,
UserMessage,
)
from PIL import Image
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from pydantic import BaseModel
from typing_extensions import Self
from ._events import WebSurferEvent
from ._prompts import (
WEB_SURFER_QA_PROMPT,
WEB_SURFER_QA_SYSTEM_MESSAGE,
WEB_SURFER_TOOL_PROMPT_MM,
WEB_SURFER_TOOL_PROMPT_TEXT,
)
from ._set_of_mark import add_set_of_mark
from ._tool_definitions import (
TOOL_CLICK,
TOOL_HISTORY_BACK,
TOOL_HOVER,
TOOL_READ_PAGE_AND_ANSWER,
TOOL_SCROLL_DOWN,
TOOL_SCROLL_UP,
TOOL_SLEEP,
TOOL_SUMMARIZE_PAGE,
TOOL_TYPE,
TOOL_VISIT_URL,
TOOL_WEB_SEARCH,
)
from ._types import InteractiveRegion, UserContent
from .playwright_controller import PlaywrightController
DEFAULT_CONTEXT_SIZE = 128000
class MultimodalWebSurferConfig(BaseModel):
name: str
model_client: ComponentModel
downloads_folder: str | None = None
description: str | None = None
debug_dir: str | None = None
headless: bool = True
start_page: str | None = "https://www.bing.com/"
animate_actions: bool = False
to_save_screenshots: bool = False
use_ocr: bool = False
browser_channel: str | None = None
browser_data_dir: str | None = None
to_resize_viewport: bool = True
class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
"""
MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[web-surfer]"
It launches a chromium browser and allows the playwright to interact with the web browser and can perform a variety of actions. The browser is launched on the first call to the agent and is reused for subsequent calls.
It must be used with a multimodal model client that supports function/tool calling, ideally GPT-4o currently.
When :meth:`on_messages` or :meth:`on_messages_stream` is called, the following occurs:
1) If this is the first call, the browser is initialized and the page is loaded. This is done in :meth:`_lazy_init`. The browser is only closed when :meth:`close` is called.
2) The method :meth:`_generate_reply` is called, which then creates the final response as below.
3) The agent takes a screenshot of the page, extracts the interactive elements, and prepares a set-of-mark screenshot with bounding boxes around the interactive elements.
4) The agent makes a call to the :attr:`model_client` with the SOM screenshot, history of messages, and the list of available tools.
- If the model returns a string, the agent returns the string as the final response.
- If the model returns a list of tool calls, the agent executes the tool calls with :meth:`_execute_tool` using :attr:`_playwright_controller`.
- The agent returns a final response which includes a screenshot of the page, page metadata, description of the action taken and the inner text of the webpage.
5) If at any point the agent encounters an error, it returns the error message as the final response.
.. note::
Please note that using the MultimodalWebSurfer involves interacting with a digital world designed for humans, which carries inherent risks.
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences.
Moreover, be cautious that MultimodalWebSurfer may be susceptible to prompt injection attacks from webpages.
.. note::
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
.. code-block:: python
import sys
import asyncio
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client used by the agent. Must be multimodal and support function calling.
downloads_folder (str, optional): The folder where downloads are saved. Defaults to None, no downloads are saved.
description (str, optional): The description of the agent. Defaults to MultimodalWebSurfer.DEFAULT_DESCRIPTION.
debug_dir (str, optional): The directory where debug information is saved. Defaults to None.
headless (bool, optional): Whether the browser should be headless. Defaults to True.
start_page (str, optional): The start page for the browser. Defaults to MultimodalWebSurfer.DEFAULT_START_PAGE.
animate_actions (bool, optional): Whether to animate actions. Defaults to False.
to_save_screenshots (bool, optional): Whether to save screenshots. Defaults to False.
use_ocr (bool, optional): Whether to use OCR. Defaults to False.
browser_channel (str, optional): The browser channel. Defaults to None.
browser_data_dir (str, optional): The browser data directory. Defaults to None.
to_resize_viewport (bool, optional): Whether to resize the viewport. Defaults to True.
playwright (Playwright, optional): The playwright instance. Defaults to None.
context (BrowserContext, optional): The browser context. Defaults to None.
Example usage:
The following example demonstrates how to create a web surfing agent with
a model client and run it for multiple turns.
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
async def main() -> None:
# Define an agent
web_surfer_agent = MultimodalWebSurfer(
name="MultimodalWebSurfer",
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
)
# Define a team
agent_team = RoundRobinGroupChat([web_surfer_agent], max_turns=3)
# Run the team and stream messages to the console
stream = agent_team.run_stream(task="Navigate to the AutoGen readme on GitHub.")
await Console(stream)
# Close the browser controlled by the agent
await web_surfer_agent.close()
asyncio.run(main())
"""
component_type = "agent"
component_config_schema = MultimodalWebSurferConfig
component_provider_override = "agentdhal_extensions.agents.web_surfer.MultimodalWebSurfer"
DEFAULT_DESCRIPTION = """
A helpful assistant with access to a web browser.
Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.).
It can also summarize the entire page, or answer questions based on the content of the page.
It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded.
"""
DEFAULT_START_PAGE = "https://www.bing.com/"
# Viewport dimensions
VIEWPORT_HEIGHT = 900
VIEWPORT_WIDTH = 1440
# Size of the image we send to the MLM
# Current values represent a 0.85 scaling to fit within the GPT-4v short-edge constraints (768px)
MLM_HEIGHT = 765
MLM_WIDTH = 1224
SCREENSHOT_TOKENS = 1105
def __init__(
self,
name: str,
model_client: ChatCompletionClient,
downloads_folder: str | None = None,
description: str = DEFAULT_DESCRIPTION,
debug_dir: str | None = None,
headless: bool = True,
start_page: str | None = DEFAULT_START_PAGE,
animate_actions: bool = False,
to_save_screenshots: bool = False,
use_ocr: bool = False,
browser_channel: str | None = None,
browser_data_dir: str | None = None,
to_resize_viewport: bool = True,
playwright: Playwright | None = None,
context: BrowserContext | None = None,
):
"""
Initialize the MultimodalWebSurfer.
"""
super().__init__(name, description)
if debug_dir is None and to_save_screenshots:
raise ValueError(
"Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist."
)
if model_client.model_info["function_calling"] is False:
raise ValueError(
"The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
)
self._model_client = model_client
self.headless = headless
self.browser_channel = browser_channel
self.browser_data_dir = browser_data_dir
self.start_page = start_page or self.DEFAULT_START_PAGE
self.downloads_folder = downloads_folder
self.debug_dir = debug_dir
self.to_save_screenshots = to_save_screenshots
self.use_ocr = use_ocr
self.to_resize_viewport = to_resize_viewport
self.animate_actions = animate_actions
# Call init to set these in case not set
self._playwright: Playwright | None = playwright
self._context: BrowserContext | None = context
self._page: Page | None = None
self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.name}.MultimodalWebSurfer")
self._chat_history: List[LLMMessage] = []
# Define the download handler
def _download_handler(download: Download) -> None:
self._last_download = download
self._download_handler = _download_handler
# Define the Playwright controller that handles the browser interactions
self._playwright_controller = PlaywrightController(
animate_actions=self.animate_actions,
downloads_folder=self.downloads_folder,
viewport_width=self.VIEWPORT_WIDTH,
viewport_height=self.VIEWPORT_HEIGHT,
_download_handler=self._download_handler,
to_resize_viewport=self.to_resize_viewport,
)
self.default_tools = [
TOOL_VISIT_URL,
TOOL_WEB_SEARCH,
TOOL_HISTORY_BACK,
TOOL_CLICK,
TOOL_TYPE,
TOOL_READ_PAGE_AND_ANSWER,
TOOL_SUMMARIZE_PAGE,
TOOL_SLEEP,
TOOL_HOVER,
]
self.did_lazy_init = False # flag to check if we have initialized the browser
async def _lazy_init(
self,
) -> None:
"""
On the first call, we initialize the browser and the page.
"""
# Check the current event loop policy if on windows.
if sys.platform == "win32":
current_policy = asyncio.get_event_loop_policy()
if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance(
current_policy, asyncio.WindowsProactorEventLoopPolicy
):
warnings.warn(
"The current event loop policy is not WindowsProactorEventLoopPolicy. "
"This may cause issues with subprocesses. "
"Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
"For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
"See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.",
stacklevel=2,
)
self._last_download = None
self._prior_metadata_hash = None
# Create the playwright self
launch_args: Dict[str, Any] = {"headless": self.headless}
if self.browser_channel is not None:
launch_args["channel"] = self.browser_channel
if self._playwright is None:
self._playwright = await async_playwright().start()
# Create the context -- are we launching persistent?
if self._context is None:
if self.browser_data_dir is None:
browser = await self._playwright.chromium.launch(**launch_args)
self._context = await browser.new_context(
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0"
)
else:
self._context = await self._playwright.chromium.launch_persistent_context(
self.browser_data_dir, **launch_args
)
# Create the page
self._context.set_default_timeout(60000) # One minute
self._page = await self._context.new_page()
assert self._page is not None
# self._page.route(lambda x: True, self._route_handler)
self._page.on("download", self._download_handler)
if self.to_resize_viewport:
await self._page.set_viewport_size({"width": self.VIEWPORT_WIDTH, "height": self.VIEWPORT_HEIGHT})
await self._page.add_init_script(
path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
)
await self._page.goto(self.start_page)
await self._page.wait_for_load_state()
# Prepare the debug directory -- which stores the screenshots generated throughout the process
await self._set_debug_dir(self.debug_dir)
self.did_lazy_init = True
async def close(self) -> None:
"""
Close the browser and the page.
Should be called when the agent is no longer needed.
"""
if self._page is not None:
await self._page.close()
self._page = None
if self._context is not None:
await self._context.close()
self._context = None
if self._playwright is not None:
await self._playwright.stop()
self._playwright = None
async def _set_debug_dir(self, debug_dir: str | None) -> None:
assert self._page is not None
if self.debug_dir is None:
return
if not os.path.isdir(self.debug_dir):
os.mkdir(self.debug_dir)
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Screenshot: " + screenshot_png_name,
)
)
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (MultiModalMessage,)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
if not self.did_lazy_init:
return
assert self._page is not None
self._chat_history.clear()
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
self._page, self.start_page
)
if reset_last_download and self._last_download is not None:
self._last_download = None
if reset_prior_metadata and self._prior_metadata_hash is not None:
self._prior_metadata_hash = None
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Screenshot: " + screenshot_png_name,
)
)
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Resetting browser.",
)
)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
for chat_message in messages:
self._chat_history.append(chat_message.to_model_message())
self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
self.model_usage: List[RequestUsage] = []
try:
content = await self._generate_reply(cancellation_token=cancellation_token)
self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
final_usage = RequestUsage(
prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]),
completion_tokens=sum([u.completion_tokens for u in self.model_usage]),
)
if isinstance(content, str):
yield Response(
chat_message=TextMessage(content=content, source=self.name, models_usage=final_usage),
inner_messages=self.inner_messages,
)
else:
yield Response(
chat_message=MultiModalMessage(content=content, source=self.name, models_usage=final_usage),
inner_messages=self.inner_messages,
)
except BaseException:
content = f"Web surfing error:\n\n{traceback.format_exc()}"
self._chat_history.append(AssistantMessage(content=content, source=self.name))
yield Response(chat_message=TextMessage(content=content, source=self.name))
async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:
"""Generates the actual reply. First calls the LLM to figure out which tool to use, then executes the tool."""
# Lazy init, initialize the browser and the page on the first generate reply only
if not self.did_lazy_init:
await self._lazy_init()
assert self._page is not None
# Clone the messages, removing old screenshots
history: List[LLMMessage] = remove_images(self._chat_history)
# Split the history, removing the last message
if len(history):
user_request = history.pop()
else:
user_request = UserMessage(content="Empty request.", source="user")
# Truncate the history for smaller models
if self._model_client.model_info["family"] not in [
ModelFamily.GPT_4O,
ModelFamily.O1,
ModelFamily.O3,
ModelFamily.GPT_4,
ModelFamily.GPT_35,
]:
history = []
# Ask the page for interactive elements, then prepare the state-of-mark screenshot
rects = await self._playwright_controller.get_interactive_rects(self._page)
viewport = await self._playwright_controller.get_visual_viewport(self._page)
screenshot = await self._page.screenshot()
som_screenshot, visible_rects, rects_above, rects_below = add_set_of_mark(screenshot, rects)
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot_som" + current_timestamp + ".png"
som_screenshot.save(os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Screenshot: " + screenshot_png_name,
)
)
# What tools are available?
tools = self.default_tools.copy()
# We can scroll up
if viewport["pageTop"] > 5:
tools.append(TOOL_SCROLL_UP)
# Can scroll down
if (viewport["pageTop"] + viewport["height"] + 5) < viewport["scrollHeight"]:
tools.append(TOOL_SCROLL_DOWN)
# Focus hint
focused = await self._playwright_controller.get_focused_rect_id(self._page)
focused_hint = ""
if focused:
name = self._target_name(focused, rects)
if name:
name = f"(and name '{name}') "
else:
name = ""
role = "control"
try:
role = rects[focused]["role"]
except KeyError:
pass
focused_hint = f"\nThe {role} with ID {focused} {name}currently has the input focus.\n\n"
# Everything visible
visible_targets = "\n".join(self._format_target_list(visible_rects, rects)) + "\n\n"
# Everything else
other_targets: List[str] = []
other_targets.extend(self._format_target_list(rects_above, rects))
other_targets.extend(self._format_target_list(rects_below, rects))
if len(other_targets) > 0:
if len(other_targets) > 30:
other_targets = other_targets[0:30]
other_targets.append("...")
other_targets_str = (
"Additional valid interaction targets include (but are not limited to):\n"
+ "\n".join(other_targets)
+ "\n\n"
)
else:
other_targets_str = ""
state_description = "Your " + await self._get_state_description()
tool_names = "\n".join([t["name"] for t in tools])
page_title = await self._page.title()
prompt_message = None
if self._model_client.model_info["vision"]:
text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
state_description=state_description,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
title=page_title,
url=self._page.url,
).strip()
# Scale the screenshot for the MLM, and close the original
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
som_screenshot.close()
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore
# Create the message
prompt_message = UserMessage(
content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)],
source=self.name,
)
else:
text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
state_description=state_description,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
title=page_title,
url=self._page.url,
).strip()
# Create the message
prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name)
history.append(prompt_message)
history.append(user_request)
# {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]}
# print(f"""
# ================={len(history)}=================
# {history[-2].content}
# =====
# {history[-1].content}
# ===================================================
# """)
# Make the request
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False})
self.model_usage.append(response.usage)
message = response.content
self._last_download = None
if isinstance(message, str):
# Answer directly
self.inner_messages.append(TextMessage(content=message, source=self.name))
return message
elif isinstance(message, list):
# Take an action
return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
async def _execute_tool(
self,
message: List[FunctionCall],
rects: Dict[str, InteractiveRegion],
tool_names: str,
cancellation_token: Optional[CancellationToken] = None,
) -> UserContent:
# Execute the tool
name = message[0].name
args = json.loads(message[0].arguments)
action_description = ""
assert self._page is not None
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
action=name,
arguments=args,
message=f"{name}( {json.dumps(args)} )",
)
)
self.inner_messages.append(TextMessage(content=f"{name}( {json.dumps(args)} )", source=self.name))
if name == "visit_url":
url = args.get("url")
action_description = f"I typed '{url}' into the browser address bar."
# Check if the argument starts with a known protocol
if url.startswith(("https://", "http://", "file://", "about:")):
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
self._page, url
)
# If the argument contains a space, treat it as a search query
elif " " in url:
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
self._page, f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH"
)
# Otherwise, prefix with https://
else:
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
self._page, "https://" + url
)
if reset_last_download and self._last_download is not None:
self._last_download = None
if reset_prior_metadata and self._prior_metadata_hash is not None:
self._prior_metadata_hash = None
elif name == "history_back":
action_description = "I clicked the browser back button."
await self._playwright_controller.back(self._page)
elif name == "web_search":
query = args.get("query")
action_description = f"I typed '{query}' into the browser search bar."
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
self._page, f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH"
)
if reset_last_download and self._last_download is not None:
self._last_download = None
if reset_prior_metadata and self._prior_metadata_hash is not None:
self._prior_metadata_hash = None
elif name == "scroll_up":
action_description = "I scrolled up one page in the browser."
await self._playwright_controller.page_up(self._page)
elif name == "scroll_down":
action_description = "I scrolled down one page in the browser."
await self._playwright_controller.page_down(self._page)
elif name == "click":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I clicked '{target_name}'."
else:
action_description = "I clicked the control."
new_page_tentative = await self._playwright_controller.click_id(self._page, target_id)
if new_page_tentative is not None:
self._page = new_page_tentative
self._prior_metadata_hash = None
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="New tab or window.",
)
)
elif name == "input_text":
input_field_id = str(args.get("input_field_id"))
text_value = str(args.get("text_value"))
input_field_name = self._target_name(input_field_id, rects)
if input_field_name:
action_description = f"I typed '{text_value}' into '{input_field_name}'."
else:
action_description = f"I input '{text_value}'."
await self._playwright_controller.fill_id(self._page, input_field_id, text_value)
elif name == "scroll_element_up":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' up."
else:
action_description = "I scrolled the control up."
await self._playwright_controller.scroll_id(self._page, target_id, "up")
elif name == "scroll_element_down":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' down."
else:
action_description = "I scrolled the control down."
await self._playwright_controller.scroll_id(self._page, target_id, "down")
elif name == "answer_question":
question = str(args.get("question"))
action_description = f"I answered the following question '{question}' based on the web page."
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return await self._summarize_page(question=question, cancellation_token=cancellation_token)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
action_description = "I summarized the current web page"
return await self._summarize_page(cancellation_token=cancellation_token)
elif name == "hover":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I hovered over '{target_name}'."
else:
action_description = "I hovered over the control."
await self._playwright_controller.hover_id(self._page, target_id)
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
await self._playwright_controller.sleep(self._page, 3)
else:
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
await self._page.wait_for_load_state()
await self._playwright_controller.sleep(self._page, 3)
# Handle downloads
if self._last_download is not None and self.downloads_folder is not None:
fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
await self._last_download.save_as(fname) # type: ignore
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
await self._page.goto(
"data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
)
await self._page.wait_for_load_state()
# Handle metadata
page_metadata = json.dumps(await self._playwright_controller.get_page_metadata(self._page), indent=4)
metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest()
if metadata_hash != self._prior_metadata_hash:
page_metadata = (
"\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
)
else:
page_metadata = ""
self._prior_metadata_hash = metadata_hash
new_screenshot = await self._page.screenshot()
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file: # type: ignore
await file.write(new_screenshot) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Screenshot: " + screenshot_png_name,
)
)
# Return the complete observation
state_description = "The " + await self._get_state_description()
message_content = (
f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page."
)
return [
re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines
AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))),
]
async def _get_state_description(self) -> str:
assert self._playwright_controller is not None
assert self._page is not None
# Describe the viewport of the new page in words
viewport = await self._playwright_controller.get_visual_viewport(self._page)
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
if percent_scrolled < 1: # Allow some rounding error
position_text = "at the top of the page"
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
position_text = "at the bottom of the page"
else:
position_text = str(percent_scrolled) + "% down from the top of the page"
visible_text = await self._playwright_controller.get_visible_text(self._page)
# Return the complete observation
page_title = await self._page.title()
message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n"
message_content += f"The following text is visible in the viewport:\n\n{visible_text}"
return message_content
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
try:
return rects[target]["aria_name"].strip()
except KeyError:
return None
def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion]) -> List[str]:
"""
Format the list of targets in the webpage as a string to be used in the agent's prompt.
"""
targets: List[str] = []
for r in list(set(ids)):
if r in rects:
# Get the role
aria_role = rects[r].get("role", "").strip()
if len(aria_role) == 0:
aria_role = rects[r].get("tag_name", "").strip()
# Get the name
aria_name = re.sub(r"[\n\r]+", " ", rects[r].get("aria_name", "")).strip()
# What are the actions?
actions = ['"click", "hover"']
if rects[r]["role"] in ["textbox", "searchbox", "search"]:
actions = ['"input_text"']
actions_str = "[" + ",".join(actions) + "]"
targets.append(f'{{"id": {r}, "name": "{aria_name}", "role": "{aria_role}", "tools": {actions_str} }}')
return targets
async def _summarize_page(
self,
question: str | None = None,
cancellation_token: Optional[CancellationToken] = None,
) -> str:
assert self._page is not None
page_markdown: str = await self._playwright_controller.get_page_markdown(self._page)
title: str = self._page.url
try:
title = await self._page.title()
except Exception:
pass
# Take a screenshot and scale it
screenshot = Image.open(io.BytesIO(await self._page.screenshot()))
scaled_screenshot = screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
screenshot.close()
ag_image = AGImage.from_pil(scaled_screenshot)
# Prepare the system prompt
messages: List[LLMMessage] = []
messages.append(SystemMessage(content=WEB_SURFER_QA_SYSTEM_MESSAGE))
prompt = WEB_SURFER_QA_PROMPT(title, question)
# Grow the buffer (which is added to the prompt) until we overflow the context window or run out of lines
buffer = ""
# for line in re.split(r"([\r\n]+)", page_markdown):
for line in page_markdown.splitlines():
trial_message = UserMessage(
content=prompt + buffer + line,
source=self.name,
)
try:
remaining = self._model_client.remaining_tokens(messages + [trial_message])
except KeyError:
# Use the default if the model isn't found
remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message])
if self._model_client.model_info["vision"] and remaining <= 0:
break
if self._model_client.model_info["vision"] and remaining <= self.SCREENSHOT_TOKENS:
break
buffer += line
# Nothing to do
buffer = buffer.strip()
if len(buffer) == 0:
return "Nothing to summarize."
# Append the message
if self._model_client.model_info["vision"]:
# Multimodal
messages.append(
UserMessage(
content=[
prompt + buffer,
ag_image,
],
source=self.name,
)
)
else:
# Text only
messages.append(
UserMessage(
content=prompt + buffer,
source=self.name,
)
)
# Generate the response
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
self.model_usage.append(response.usage)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content
def _to_config(self) -> MultimodalWebSurferConfig:
return MultimodalWebSurferConfig(
name=self.name,
model_client=self._model_client.dump_component(),
downloads_folder=self.downloads_folder,
description=self.description,
debug_dir=self.debug_dir,
headless=self.headless,
start_page=self.start_page,
animate_actions=self.animate_actions,
to_save_screenshots=self.to_save_screenshots,
use_ocr=self.use_ocr,
browser_channel=self.browser_channel,
browser_data_dir=self.browser_data_dir,
to_resize_viewport=self.to_resize_viewport,
)
@classmethod
def _from_config(cls, config: MultimodalWebSurferConfig) -> Self:
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
downloads_folder=config.downloads_folder,
description=config.description or cls.DEFAULT_DESCRIPTION,
debug_dir=config.debug_dir,
headless=config.headless,
start_page=config.start_page or cls.DEFAULT_START_PAGE,
animate_actions=config.animate_actions,
to_save_screenshots=config.to_save_screenshots,
use_ocr=config.use_ocr,
browser_channel=config.browser_channel,
browser_data_dir=config.browser_data_dir,
to_resize_viewport=config.to_resize_viewport,
)

View File

@@ -0,0 +1,52 @@
WEB_SURFER_TOOL_PROMPT_MM = """
{state_description}
Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
My request follows:
"""
WEB_SURFER_TOOL_PROMPT_TEXT = """
{state_description}
You have also identified the following interactive components:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
My request follows:
"""
WEB_SURFER_QA_SYSTEM_MESSAGE = """
You are a helpful assistant that can summarize long documents to answer question.
"""
def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str:
base_prompt = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport."
if question is not None:
return (
f"{base_prompt} Please summarize the webpage into one or two paragraphs with respect to '{question}':\n\n"
)
else:
return f"{base_prompt} Please summarize the webpage into one or two paragraphs:\n\n"

View File

@@ -0,0 +1,96 @@
import io
import random
from typing import BinaryIO, Dict, List, Tuple, cast
from PIL import Image, ImageDraw, ImageFont
from ._types import DOMRectangle, InteractiveRegion
TOP_NO_LABEL_ZONE = 20 # Don't print any labels close the top of the page
def add_set_of_mark(
screenshot: bytes | Image.Image | io.BufferedIOBase, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
if isinstance(screenshot, Image.Image):
return _add_set_of_mark(screenshot, ROIs)
if isinstance(screenshot, bytes):
screenshot = io.BytesIO(screenshot)
# TODO: Not sure why this cast was needed, but by this point screenshot is a binary file-like object
image = Image.open(cast(BinaryIO, screenshot))
comp, visible_rects, rects_above, rects_below = _add_set_of_mark(image, ROIs)
image.close()
return comp, visible_rects, rects_above, rects_below
def _add_set_of_mark(
screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
visible_rects: List[str] = list()
rects_above: List[str] = list() # Scroll up to see
rects_below: List[str] = list() # Scroll down to see
fnt = ImageFont.load_default(14)
base = screenshot.convert("L").convert("RGBA")
overlay = Image.new("RGBA", base.size)
draw = ImageDraw.Draw(overlay)
for r in ROIs:
for rect in ROIs[r]["rects"]:
# Empty rectangles
if not rect:
continue
if rect["width"] * rect["height"] == 0:
continue
mid = ((rect["right"] + rect["left"]) / 2.0, (rect["top"] + rect["bottom"]) / 2.0)
if 0 <= mid[0] and mid[0] < base.size[0]:
if mid[1] < 0:
rects_above.append(r)
elif mid[1] >= base.size[1]:
rects_below.append(r)
else:
visible_rects.append(r)
_draw_roi(draw, int(r), fnt, rect)
comp = Image.alpha_composite(base, overlay)
overlay.close()
return comp, visible_rects, rects_above, rects_below
def _draw_roi(
draw: ImageDraw.ImageDraw, idx: int, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, rect: DOMRectangle
) -> None:
color = _color(idx)
luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11
text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)
roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))
label_location = (rect["right"], rect["top"])
label_anchor = "rb"
if label_location[1] <= TOP_NO_LABEL_ZONE:
label_location = (rect["right"], rect["bottom"])
label_anchor = "rt"
draw.rectangle(roi, outline=color, fill=(color[0], color[1], color[2], 48), width=2)
# TODO: Having trouble with these types being partially Unknown.
bbox = draw.textbbox(label_location, str(idx), font=font, anchor=label_anchor, align="center") # type: ignore
bbox = (bbox[0] - 3, bbox[1] - 3, bbox[2] + 3, bbox[3] + 3)
draw.rectangle(bbox, fill=color)
# TODO: Having trouble with these types being partially Unknown.
draw.text(label_location, str(idx), fill=text_color, font=font, anchor=label_anchor, align="center") # type: ignore
def _color(identifier: int) -> Tuple[int, int, int, int]:
rnd = random.Random(int(identifier))
color = [rnd.randint(0, 255), rnd.randint(125, 255), rnd.randint(0, 50)]
rnd.shuffle(color)
color.append(255)
return cast(Tuple[int, int, int, int], tuple(color))

View File

@@ -0,0 +1,317 @@
from typing import Any, Dict
from agentdhal_core.tools._base import ParametersSchema, ToolSchema
def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema:
return ToolSchema(
name=tooldef["function"]["name"],
description=tooldef["function"]["description"],
parameters=ParametersSchema(
type="object",
properties=tooldef["function"]["parameters"]["properties"],
required=tooldef["function"]["parameters"]["required"],
),
)
REASONING_TOOL_PROMPT = (
"A short description of the action to be performed and reason for doing so, do not mention the user."
)
TOOL_VISIT_URL: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "visit_url",
"description": "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"url": {
"type": "string",
"description": "The URL to visit in the browser.",
},
},
"required": ["reasoning", "url"],
},
},
}
)
TOOL_WEB_SEARCH: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "web_search",
"description": "Performs a web search on Bing.com with the given query.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"query": {
"type": "string",
"description": "The web search query to use.",
},
},
"required": ["reasoning", "query"],
},
},
}
)
TOOL_HISTORY_BACK: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "history_back",
"description": "Navigates back one page in the browser's history. This is equivalent to clicking the browser back button.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
},
"required": ["reasoning"],
},
},
}
)
TOOL_SCROLL_UP: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "scroll_up",
"description": "Scrolls the entire browser viewport one page UP towards the beginning.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
},
"required": ["reasoning"],
},
},
}
)
TOOL_SCROLL_DOWN: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "scroll_down",
"description": "Scrolls the entire browser viewport one page DOWN towards the end.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
},
"required": ["reasoning"],
},
},
}
)
TOOL_CLICK: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "click",
"description": "Clicks the mouse on the target with the given id.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"target_id": {
"type": "integer",
"description": "The numeric id of the target to click.",
},
},
"required": ["reasoning", "target_id"],
},
},
}
)
TOOL_TYPE: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "input_text",
"description": "Types the given text value into the specified field.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"input_field_id": {
"type": "integer",
"description": "The numeric id of the input field to receive the text.",
},
"text_value": {
"type": "string",
"description": "The text to type into the input field.",
},
},
"required": ["reasoning", "input_field_id", "text_value"],
},
},
}
)
TOOL_SCROLL_ELEMENT_DOWN: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "scroll_element_down",
"description": "Scrolls a given html element (e.g., a div or a menu) DOWN.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"target_id": {
"type": "integer",
"description": "The numeric id of the target to scroll down.",
},
},
"required": ["reasoning", "target_id"],
},
},
}
)
TOOL_SCROLL_ELEMENT_UP: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "scroll_element_up",
"description": "Scrolls a given html element (e.g., a div or a menu) UP.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"target_id": {
"type": "integer",
"description": "The numeric id of the target to scroll UP.",
},
},
"required": ["reasoning", "target_id"],
},
},
}
)
TOOL_HOVER: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "hover",
"description": "Hovers the mouse over the target with the given id.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"target_id": {
"type": "integer",
"description": "The numeric id of the target to hover over.",
},
},
"required": ["reasoning", "target_id"],
},
},
}
)
TOOL_READ_PAGE_AND_ANSWER: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "answer_question",
"description": "Uses AI to answer a question about the current webpage's content.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
"question": {
"type": "string",
"description": "The question to answer.",
},
},
"required": ["reasoning", "question"],
},
},
}
)
TOOL_SUMMARIZE_PAGE: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "summarize_page",
"description": "Uses AI to summarize the entire page.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
},
"required": ["reasoning"],
},
},
}
)
TOOL_SLEEP: ToolSchema = _load_tool(
{
"type": "function",
"function": {
"name": "sleep",
"description": "Wait a short period of time. Call this function if the page has not yet fully loaded, or if it is determined that a small delay would increase the task's chances of success.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": REASONING_TOOL_PROMPT,
},
},
"required": ["reasoning"],
},
},
}
)

View File

@@ -0,0 +1,106 @@
from typing import Any, Dict, List, TypedDict, Union
from agentdhal_core import FunctionCall, Image
from agentdhal_core.models import FunctionExecutionResult
UserContent = Union[str, List[Union[str, Image]]]
AssistantContent = Union[str, List[FunctionCall]]
FunctionExecutionContent = List[FunctionExecutionResult]
SystemContent = str
class DOMRectangle(TypedDict):
x: Union[int, float]
y: Union[int, float]
width: Union[int, float]
height: Union[int, float]
top: Union[int, float]
right: Union[int, float]
bottom: Union[int, float]
left: Union[int, float]
class VisualViewport(TypedDict):
height: Union[int, float]
width: Union[int, float]
offsetLeft: Union[int, float]
offsetTop: Union[int, float]
pageLeft: Union[int, float]
pageTop: Union[int, float]
scale: Union[int, float]
clientWidth: Union[int, float]
clientHeight: Union[int, float]
scrollWidth: Union[int, float]
scrollHeight: Union[int, float]
class InteractiveRegion(TypedDict):
tag_name: str
role: str
aria_name: str
v_scrollable: bool
rects: List[DOMRectangle]
# Helper functions for dealing with JSON. Not sure there's a better way?
def _get_str(d: Any, k: str) -> str:
val = d[k]
assert isinstance(val, str)
return val
def _get_number(d: Any, k: str) -> Union[int, float]:
val = d[k]
assert isinstance(val, int) or isinstance(val, float)
return val
def _get_bool(d: Any, k: str) -> bool:
val = d[k]
assert isinstance(val, bool)
return val
def domrectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle:
return DOMRectangle(
x=_get_number(rect, "x"),
y=_get_number(rect, "y"),
width=_get_number(rect, "width"),
height=_get_number(rect, "height"),
top=_get_number(rect, "top"),
right=_get_number(rect, "right"),
bottom=_get_number(rect, "bottom"),
left=_get_number(rect, "left"),
)
def interactiveregion_from_dict(region: Dict[str, Any]) -> InteractiveRegion:
typed_rects: List[DOMRectangle] = []
for rect in region["rects"]:
typed_rects.append(domrectangle_from_dict(rect))
return InteractiveRegion(
tag_name=_get_str(region, "tag_name"),
role=_get_str(region, "role"),
aria_name=_get_str(region, "aria-name"),
v_scrollable=_get_bool(region, "v-scrollable"),
rects=typed_rects,
)
def visualviewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport:
return VisualViewport(
height=_get_number(viewport, "height"),
width=_get_number(viewport, "width"),
offsetLeft=_get_number(viewport, "offsetLeft"),
offsetTop=_get_number(viewport, "offsetTop"),
pageLeft=_get_number(viewport, "pageLeft"),
pageTop=_get_number(viewport, "pageTop"),
scale=_get_number(viewport, "scale"),
clientWidth=_get_number(viewport, "clientWidth"),
clientHeight=_get_number(viewport, "clientHeight"),
scrollWidth=_get_number(viewport, "scrollWidth"),
scrollHeight=_get_number(viewport, "scrollHeight"),
)

View File

@@ -0,0 +1,429 @@
var MultimodalWebSurfer = MultimodalWebSurfer || (function() {
let nextLabel = 10;
let roleMapping = {
"a": "link",
"area": "link",
"button": "button",
"input, type=button": "button",
"input, type=checkbox": "checkbox",
"input, type=email": "textbox",
"input, type=number": "spinbutton",
"input, type=radio": "radio",
"input, type=range": "slider",
"input, type=reset": "button",
"input, type=search": "searchbox",
"input, type=submit": "button",
"input, type=tel": "textbox",
"input, type=text": "textbox",
"input, type=url": "textbox",
"search": "search",
"select": "combobox",
"option": "option",
"textarea": "textbox"
};
let getCursor = function(elm) {
return window.getComputedStyle(elm)["cursor"];
};
let getInteractiveElements = function() {
let results = []
let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"];
let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"];
// Get the main interactive elements
let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])");
for (let i=0; i<nodeList.length; i++) { // Copy to something mutable
results.push(nodeList[i]);
}
// Anything not already included that has a suitable role
nodeList = document.querySelectorAll("[role]");
for (let i=0; i<nodeList.length; i++) { // Copy to something mutable
if (results.indexOf(nodeList[i]) == -1) {
let role = nodeList[i].getAttribute("role");
if (roles.indexOf(role) > -1) {
results.push(nodeList[i]);
}
}
}
// Any element that changes the cursor to something implying interactivity
nodeList = document.querySelectorAll("*");
for (let i=0; i<nodeList.length; i++) {
let node = nodeList[i];
// Cursor is default, or does not suggest interactivity
let cursor = getCursor(node);
if (inertCursors.indexOf(cursor) >= 0) {
continue;
}
// Move up to the first instance of this cursor change
parent = node.parentNode;
while (parent && getCursor(parent) == cursor) {
node = parent;
parent = node.parentNode;
}
// Add the node if it is new
if (results.indexOf(node) == -1) {
results.push(node);
}
}
return results;
};
let labelElements = function(elements) {
for (let i=0; i<elements.length; i++) {
if (!elements[i].hasAttribute("__elementId")) {
elements[i].setAttribute("__elementId", "" + (nextLabel++));
}
}
};
let isTopmost = function(element, x, y) {
let hit = document.elementFromPoint(x, y);
// Hack to handle elements outside the viewport
if (hit === null) {
return true;
}
while (hit) {
if (hit == element) return true;
hit = hit.parentNode;
}
return false;
};
let getFocusedElementId = function() {
let elm = document.activeElement;
while (elm) {
if (elm.hasAttribute && elm.hasAttribute("__elementId")) {
return elm.getAttribute("__elementId");
}
elm = elm.parentNode;
}
return null;
};
let trimmedInnerText = function(element) {
if (!element) {
return "";
}
let text = element.innerText;
if (!text) {
return "";
}
return text.trim();
};
let getApproximateAriaName = function(element) {
// Check for aria labels
if (element.hasAttribute("aria-labelledby")) {
let buffer = "";
let ids = element.getAttribute("aria-labelledby").split(" ");
for (let i=0; i<ids.length; i++) {
let label = document.getElementById(ids[i]);
if (label) {
buffer = buffer + " " + trimmedInnerText(label);
}
}
return buffer.trim();
}
if (element.hasAttribute("aria-label")) {
return element.getAttribute("aria-label");
}
// Check for labels
if (element.hasAttribute("id")) {
let label_id = element.getAttribute("id");
let label = "";
let labels = document.querySelectorAll("label[for='" + label_id + "']");
for (let j=0; j<labels.length; j++) {
label += labels[j].innerText + " ";
}
label = label.trim();
if (label != "") {
return label;
}
}
if (element.parentElement && element.parentElement.tagName == "LABEL") {
return element.parentElement.innerText;
}
// Check for alt text or titles
if (element.hasAttribute("alt")) {
return element.getAttribute("alt")
}
if (element.hasAttribute("title")) {
return element.getAttribute("title")
}
return trimmedInnerText(element);
};
let getApproximateAriaRole = function(element) {
let tag = element.tagName.toLowerCase();
if (tag == "input" && element.hasAttribute("type")) {
tag = tag + ", type=" + element.getAttribute("type");
}
if (element.hasAttribute("role")) {
return [element.getAttribute("role"), tag];
}
else if (tag in roleMapping) {
return [roleMapping[tag], tag];
}
else {
return ["", tag];
}
};
let getInteractiveRects = function() {
labelElements(getInteractiveElements());
let elements = document.querySelectorAll("[__elementId]");
let results = {};
for (let i=0; i<elements.length; i++) {
let key = elements[i].getAttribute("__elementId");
let rects = elements[i].getClientRects();
let ariaRole = getApproximateAriaRole(elements[i]);
let ariaName = getApproximateAriaName(elements[i]);
let vScrollable = elements[i].scrollHeight - elements[i].clientHeight >= 1;
let record = {
"tag_name": ariaRole[1],
"role": ariaRole[0],
"aria-name": ariaName,
"v-scrollable": vScrollable,
"rects": []
};
for (const rect of rects) {
let x = rect.left + rect.width/2;
let y = rect.top + rect.height/2;
if (isTopmost(elements[i], x, y)) {
record["rects"].push(JSON.parse(JSON.stringify(rect)));
}
}
if (record["rects"].length > 0) {
results[key] = record;
}
}
return results;
};
let getVisualViewport = function() {
let vv = window.visualViewport;
let de = document.documentElement;
return {
"height": vv ? vv.height : 0,
"width": vv ? vv.width : 0,
"offsetLeft": vv ? vv.offsetLeft : 0,
"offsetTop": vv ? vv.offsetTop : 0,
"pageLeft": vv ? vv.pageLeft : 0,
"pageTop": vv ? vv.pageTop : 0,
"scale": vv ? vv.scale : 0,
"clientWidth": de ? de.clientWidth : 0,
"clientHeight": de ? de.clientHeight : 0,
"scrollWidth": de ? de.scrollWidth : 0,
"scrollHeight": de ? de.scrollHeight : 0
};
};
let _getMetaTags = function() {
let meta = document.querySelectorAll("meta");
let results = {};
for (let i = 0; i<meta.length; i++) {
let key = null;
if (meta[i].hasAttribute("name")) {
key = meta[i].getAttribute("name");
}
else if (meta[i].hasAttribute("property")) {
key = meta[i].getAttribute("property");
}
else {
continue;
}
if (meta[i].hasAttribute("content")) {
results[key] = meta[i].getAttribute("content");
}
}
return results;
};
let _getJsonLd = function() {
let jsonld = [];
let scripts = document.querySelectorAll('script[type="application/ld+json"]');
for (let i=0; i<scripts.length; i++) {
jsonld.push(scripts[i].innerHTML.trim());
}
return jsonld;
};
// From: https://www.stevefenton.co.uk/blog/2022/12/parse-microdata-with-javascript/
let _getMicrodata = function() {
function sanitize(input) {
return input.replace(/\s/gi, ' ').trim();
}
function addValue(information, name, value) {
if (information[name]) {
if (typeof information[name] === 'array') {
information[name].push(value);
} else {
const arr = [];
arr.push(information[name]);
arr.push(value);
information[name] = arr;
}
} else {
information[name] = value;
}
}
function traverseItem(item, information) {
const children = item.children;
for (let i = 0; i < children.length; i++) {
const child = children[i];
if (child.hasAttribute('itemscope')) {
if (child.hasAttribute('itemprop')) {
const itemProp = child.getAttribute('itemprop');
const itemType = child.getAttribute('itemtype');
const childInfo = {
itemType: itemType
};
traverseItem(child, childInfo);
itemProp.split(' ').forEach(propName => {
addValue(information, propName, childInfo);
});
}
} else if (child.hasAttribute('itemprop')) {
const itemProp = child.getAttribute('itemprop');
itemProp.split(' ').forEach(propName => {
if (propName === 'url') {
addValue(information, propName, child.href);
} else {
addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || ""));
}
});
traverseItem(child, information);
} else {
traverseItem(child, information);
}
}
}
const microdata = [];
document.querySelectorAll("[itemscope]").forEach(function(elem, i) {
const itemType = elem.getAttribute('itemtype');
const information = {
itemType: itemType
};
traverseItem(elem, information);
microdata.push(information);
});
return microdata;
};
let getPageMetadata = function() {
let jsonld = _getJsonLd();
let metaTags = _getMetaTags();
let microdata = _getMicrodata();
let results = {}
if (jsonld.length > 0) {
try {
results["jsonld"] = JSON.parse(jsonld);
}
catch (e) {
results["jsonld"] = jsonld;
}
}
if (microdata.length > 0) {
results["microdata"] = microdata;
}
for (let key in metaTags) {
if (metaTags.hasOwnProperty(key)) {
results["meta_tags"] = metaTags;
break;
}
}
return results;
};
let getVisibleText = function() {
// Get the windows current viewport boundaries
const viewportHeight = window.innerHeight || document.documentElement.clientHeight;
const viewportWidth = window.innerWidth || document.documentElement.clientWidth;
let textInView = "";
const walker = document.createTreeWalker(
document.body,
NodeFilter.SHOW_TEXT,
null,
false
);
while (walker.nextNode()) {
const textNode = walker.currentNode;
// Create a range to retrieve bounding rectangles of the current text node
const range = document.createRange();
range.selectNodeContents(textNode);
const rects = range.getClientRects();
// Check if any rect is inside (or partially inside) the viewport
for (const rect of rects) {
const isVisible =
rect.width > 0 &&
rect.height > 0 &&
rect.bottom >= 0 &&
rect.right >= 0 &&
rect.top <= viewportHeight &&
rect.left <= viewportWidth;
if (isVisible) {
textInView += textNode.nodeValue.replace(/\s+/g, " ");
// Is the parent a block element?
if (textNode.parentNode) {
const parent = textNode.parentNode;
const style = window.getComputedStyle(parent);
if (["inline", "hidden", "none"].indexOf(style.display) === -1) {
textInView += "\n";
}
}
break; // No need to check other rects once found visible
}
}
}
// Remove blank lines from textInView
textInView = textInView.replace(/^\s*\n/gm, "").trim().replace(/\n+/g, "\n");
return textInView;
};
return {
getInteractiveRects: getInteractiveRects,
getVisualViewport: getVisualViewport,
getFocusedElementId: getFocusedElementId,
getPageMetadata: getPageMetadata,
getVisibleText: getVisibleText,
};
})();

View File

@@ -0,0 +1,578 @@
import asyncio
import base64
import io
import os
import random
import warnings
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
from playwright._impl._errors import Error as PlaywrightError
from playwright._impl._errors import TimeoutError
from playwright.async_api import Download, Page
from ._types import (
InteractiveRegion,
VisualViewport,
interactiveregion_from_dict,
visualviewport_from_dict,
)
markitdown: ModuleType | None = None
try:
# Suppress warnings from markitdown -- which is pretty chatty
warnings.filterwarnings(action="ignore", module="markitdown")
import markitdown
except ImportError:
pass
class PlaywrightController:
"""
A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling.
Args:
downloads_folder (str | None): The folder to save downloads to. If None, downloads are not saved.
animate_actions (bool): Whether to animate the actions (create fake cursor to click).
viewport_width (int): The width of the viewport.
viewport_height (int): The height of the viewport.
_download_handler (Optional[Callable[[Download], None]]): A function to handle downloads.
to_resize_viewport (bool): Whether to resize the viewport
"""
def __init__(
self,
downloads_folder: str | None = None,
animate_actions: bool = False,
viewport_width: int = 1440,
viewport_height: int = 900,
_download_handler: Optional[Callable[[Download], None]] = None,
to_resize_viewport: bool = True,
) -> None:
"""
Initialize the PlaywrightController.
"""
assert isinstance(animate_actions, bool)
assert isinstance(viewport_width, int)
assert isinstance(viewport_height, int)
assert viewport_height > 0
assert viewport_width > 0
self.animate_actions = animate_actions
self.downloads_folder = downloads_folder
self.viewport_width = viewport_width
self.viewport_height = viewport_height
self._download_handler = _download_handler
self.to_resize_viewport = to_resize_viewport
self._page_script: str = ""
self.last_cursor_position: Tuple[float, float] = (0.0, 0.0)
self._markdown_converter: Optional[Any] | None = None
# Read page_script
with open(
os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt", encoding="utf-8"
) as fh:
self._page_script = fh.read()
async def sleep(self, page: Page, duration: Union[int, float]) -> None:
"""
Pause the execution for a specified duration.
Args:
page (Page): The Playwright page object.
duration (Union[int, float]): The duration to sleep in milliseconds.
"""
assert page is not None
await page.wait_for_timeout(duration * 1000)
async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion]:
"""
Retrieve interactive regions from the web page.
Args:
page (Page): The Playwright page object.
Returns:
Dict[str, InteractiveRegion]: A dictionary of interactive regions.
"""
assert page is not None
# Read the regions from the DOM
try:
await page.evaluate(self._page_script)
except Exception:
pass
result = cast(Dict[str, Dict[str, Any]], await page.evaluate("MultimodalWebSurfer.getInteractiveRects();"))
# Convert the results into appropriate types
assert isinstance(result, dict)
typed_results: Dict[str, InteractiveRegion] = {}
for k in result:
assert isinstance(k, str)
typed_results[k] = interactiveregion_from_dict(result[k])
return typed_results
async def get_visual_viewport(self, page: Page) -> VisualViewport:
"""
Retrieve the visual viewport of the web page.
Args:
page (Page): The Playwright page object.
Returns:
VisualViewport: The visual viewport of the page.
"""
assert page is not None
try:
await page.evaluate(self._page_script)
except Exception:
pass
return visualviewport_from_dict(await page.evaluate("MultimodalWebSurfer.getVisualViewport();"))
async def get_focused_rect_id(self, page: Page) -> str | None:
"""
Retrieve the ID of the currently focused element.
Args:
page (Page): The Playwright page object.
Returns:
str: The ID of the focused element or None if no control has focus.
"""
assert page is not None
try:
await page.evaluate(self._page_script)
except Exception:
pass
result = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();")
return None if result is None else str(result)
async def get_page_metadata(self, page: Page) -> Dict[str, Any]:
"""
Retrieve metadata from the web page.
Args:
page (Page): The Playwright page object.
Returns:
Dict[str, Any]: A dictionary of page metadata.
"""
assert page is not None
try:
await page.evaluate(self._page_script)
except Exception:
pass
result = await page.evaluate("MultimodalWebSurfer.getPageMetadata();")
assert isinstance(result, dict)
return cast(Dict[str, Any], result)
async def on_new_page(self, page: Page) -> None:
"""
Handle actions to perform on a new page.
Args:
page (Page): The Playwright page object.
"""
assert page is not None
page.on("download", self._download_handler) # type: ignore
if self.to_resize_viewport and self.viewport_width and self.viewport_height:
await page.set_viewport_size({"width": self.viewport_width, "height": self.viewport_height})
await self.sleep(page, 0.2)
await page.add_init_script(path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"))
await page.wait_for_load_state()
async def back(self, page: Page) -> None:
"""
Navigate back to the previous page.
Args:
page (Page): The Playwright page object.
"""
assert page is not None
await page.go_back()
async def visit_page(self, page: Page, url: str) -> Tuple[bool, bool]:
"""
Visit a specified URL.
Args:
page (Page): The Playwright page object.
url (str): The URL to visit.
Returns:
Tuple[bool, bool]: A tuple indicating whether to reset prior metadata hash and last download.
"""
assert page is not None
reset_prior_metadata_hash = False
reset_last_download = False
try:
# Regular webpage
await page.goto(url)
await page.wait_for_load_state()
reset_prior_metadata_hash = True
except Exception as e_outer:
# Downloaded file
if self.downloads_folder and "net::ERR_ABORTED" in str(e_outer):
async with page.expect_download() as download_info:
try:
await page.goto(url)
except Exception as e_inner:
if "net::ERR_ABORTED" in str(e_inner):
pass
else:
raise e_inner
download = await download_info.value
fname = os.path.join(self.downloads_folder, download.suggested_filename)
await download.save_as(fname)
message = f"<body style=\"margin: 20px;\"><h1>Successfully downloaded '{download.suggested_filename}' to local path:<br><br>{fname}</h1></body>"
await page.goto(
"data:text/html;base64," + base64.b64encode(message.encode("utf-8")).decode("utf-8")
)
reset_last_download = True
else:
raise e_outer
return reset_prior_metadata_hash, reset_last_download
async def page_down(self, page: Page) -> None:
"""
Scroll the page down by one viewport height minus 50 pixels.
Args:
page (Page): The Playwright page object.
"""
assert page is not None
await page.evaluate(f"window.scrollBy(0, {self.viewport_height-50});")
async def page_up(self, page: Page) -> None:
"""
Scroll the page up by one viewport height minus 50 pixels.
Args:
page (Page): The Playwright page object.
"""
assert page is not None
await page.evaluate(f"window.scrollBy(0, -{self.viewport_height-50});")
async def gradual_cursor_animation(
self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float
) -> None:
"""
Animate the cursor movement gradually from start to end coordinates.
Args:
page (Page): The Playwright page object.
start_x (float): The starting x-coordinate.
start_y (float): The starting y-coordinate.
end_x (float): The ending x-coordinate.
end_y (float): The ending y-coordinate.
"""
# animation helper
steps = 20
for step in range(steps):
x = start_x + (end_x - start_x) * (step / steps)
y = start_y + (end_y - start_y) * (step / steps)
# await page.mouse.move(x, y, steps=1)
await page.evaluate(f"""
(function() {{
let cursor = document.getElementById('red-cursor');
cursor.style.left = '{x}px';
cursor.style.top = '{y}px';
}})();
""")
await asyncio.sleep(0.05)
self.last_cursor_position = (end_x, end_y)
async def add_cursor_box(self, page: Page, identifier: str) -> None:
"""
Add a red cursor box around the element with the given identifier.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
"""
# animation helper
await page.evaluate(f"""
(function() {{
let elm = document.querySelector("[__elementId='{identifier}']");
if (elm) {{
elm.style.transition = 'border 0.3s ease-in-out';
elm.style.border = '2px solid red';
}}
}})();
""")
await asyncio.sleep(0.3)
# Create a red cursor
await page.evaluate("""
(function() {
let cursor = document.createElement('div');
cursor.id = 'red-cursor';
cursor.style.width = '10px';
cursor.style.height = '10px';
cursor.style.backgroundColor = 'red';
cursor.style.position = 'absolute';
cursor.style.borderRadius = '50%';
cursor.style.zIndex = '10000';
document.body.appendChild(cursor);
})();
""")
async def remove_cursor_box(self, page: Page, identifier: str) -> None:
"""
Remove the red cursor box around the element with the given identifier.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
"""
# Remove the highlight and cursor
await page.evaluate(f"""
(function() {{
let elm = document.querySelector("[__elementId='{identifier}']");
if (elm) {{
elm.style.border = '';
}}
let cursor = document.getElementById('red-cursor');
if (cursor) {{
cursor.remove();
}}
}})();
""")
async def click_id(self, page: Page, identifier: str) -> Page | None:
"""
Click the element with the given identifier.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
Returns:
Page | None: The new page if a new page is opened, otherwise None.
"""
new_page: Page | None = None
assert page is not None
target = page.locator(f"[__elementId='{identifier}']")
# See if it exists
try:
await target.wait_for(timeout=5000)
except TimeoutError:
raise ValueError("No such element.") from None
# Click it
await target.scroll_into_view_if_needed()
await asyncio.sleep(0.3)
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
if self.animate_actions:
await self.add_cursor_box(page, identifier)
# Move cursor to the box slowly
start_x, start_y = self.last_cursor_position
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
await asyncio.sleep(0.1)
try:
# Give it a chance to open a new page
async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore
await page.mouse.click(end_x, end_y, delay=10)
new_page = await page_info.value # type: ignore
assert isinstance(new_page, Page)
await self.on_new_page(new_page)
except TimeoutError:
pass
await self.remove_cursor_box(page, identifier)
else:
try:
# Give it a chance to open a new page
async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore
await page.mouse.click(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2, delay=10)
new_page = await page_info.value # type: ignore
assert isinstance(new_page, Page)
await self.on_new_page(new_page)
except TimeoutError:
pass
return new_page # type: ignore
async def hover_id(self, page: Page, identifier: str) -> None:
"""
Hover the mouse over the element with the given identifier.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
"""
assert page is not None
target = page.locator(f"[__elementId='{identifier}']")
# See if it exists
try:
await target.wait_for(timeout=5000)
except TimeoutError:
raise ValueError("No such element.") from None
# Hover over it
await target.scroll_into_view_if_needed()
await asyncio.sleep(0.3)
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
if self.animate_actions:
await self.add_cursor_box(page, identifier)
# Move cursor to the box slowly
start_x, start_y = self.last_cursor_position
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
await asyncio.sleep(0.1)
await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2)
await self.remove_cursor_box(page, identifier)
else:
await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2)
async def fill_id(self, page: Page, identifier: str, value: str, press_enter: bool = True) -> None:
"""
Fill the element with the given identifier with the specified value.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
value (str): The value to fill.
"""
assert page is not None
target = page.locator(f"[__elementId='{identifier}']")
# See if it exists
try:
await target.wait_for(timeout=5000)
except TimeoutError:
raise ValueError("No such element.") from None
# Fill it
await target.scroll_into_view_if_needed()
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
if self.animate_actions:
await self.add_cursor_box(page, identifier)
# Move cursor to the box slowly
start_x, start_y = self.last_cursor_position
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
await asyncio.sleep(0.1)
# Focus on the element
await target.focus()
if self.animate_actions:
# fill char by char to mimic human speed for short text and type fast for long text
if len(value) < 100:
delay_typing_speed = 50 + 100 * random.random()
else:
delay_typing_speed = 10
await target.press_sequentially(value, delay=delay_typing_speed)
else:
try:
await target.fill(value)
except PlaywrightError:
await target.press_sequentially(value)
if press_enter:
await target.press("Enter")
if self.animate_actions:
await self.remove_cursor_box(page, identifier)
async def scroll_id(self, page: Page, identifier: str, direction: str) -> None:
"""
Scroll the element with the given identifier in the specified direction.
Args:
page (Page): The Playwright page object.
identifier (str): The element identifier.
direction (str): The direction to scroll ("up" or "down").
"""
assert page is not None
await page.evaluate(
f"""
(function() {{
let elm = document.querySelector("[__elementId='{identifier}']");
if (elm) {{
if ("{direction}" == "up") {{
elm.scrollTop = Math.max(0, elm.scrollTop - elm.clientHeight);
}}
else {{
elm.scrollTop = Math.min(elm.scrollHeight - elm.clientHeight, elm.scrollTop + elm.clientHeight);
}}
}}
}})();
"""
)
async def get_webpage_text(self, page: Page, n_lines: int = 50) -> str:
"""
Retrieve the text content of the web page.
Args:
page (Page): The Playwright page object.
n_lines (int): The number of lines to return from the page inner text.
Returns:
str: The text content of the page.
"""
assert page is not None
try:
text_in_viewport = await page.evaluate("""() => {
return document.body.innerText;
}""")
text_in_viewport = "\n".join(text_in_viewport.split("\n")[:n_lines])
# remove empty lines
text_in_viewport = "\n".join([line for line in text_in_viewport.split("\n") if line.strip()])
assert isinstance(text_in_viewport, str)
return text_in_viewport
except Exception:
return ""
async def get_visible_text(self, page: Page) -> str:
"""
Retrieve the text content of the browser viewport (approximately).
Args:
page (Page): The Playwright page object.
Returns:
str: The text content of the page.
"""
assert page is not None
try:
await page.evaluate(self._page_script)
except Exception:
pass
result = await page.evaluate("MultimodalWebSurfer.getVisibleText();")
assert isinstance(result, str)
return result
async def get_page_markdown(self, page: Page) -> str:
"""
Retrieve the markdown content of the web page.
Currently not implemented.
Args:
page (Page): The Playwright page object.
Returns:
str: The markdown content of the page.
"""
assert page is not None
if self._markdown_converter is None and markitdown is not None:
self._markdown_converter = markitdown.MarkItDown()
assert self._markdown_converter is not None
html = await page.evaluate("document.documentElement.outerHTML;")
res = self._markdown_converter.convert_stream(
io.BytesIO(html.encode("utf-8")), file_extension=".html", url=page.url
)
assert hasattr(res, "text_content") and isinstance(res.text_content, str)
return res.text_content
else:
return await self.get_webpage_text(page, n_lines=200)

View File

@@ -0,0 +1,56 @@
from typing import List
from agentdhal_core import Component, ComponentBase
from pydantic import BaseModel
from typing_extensions import Self
from azure.core.credentials import TokenProvider
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
class TokenProviderConfig(BaseModel):
provider_kind: str
scopes: List[str]
class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]):
component_type = "token_provider"
component_config_schema = TokenProviderConfig
component_provider_override = "agentdhal_extensions.auth.azure.AzureTokenProvider"
def __init__(self, credential: TokenProvider, *scopes: str):
self.credential = credential
self.scopes = list(scopes)
self.provider = get_bearer_token_provider(self.credential, *self.scopes)
def __call__(self) -> str:
return self.provider()
def _to_config(self) -> TokenProviderConfig:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Returns:
T: The configuration of the component.
"""
if isinstance(self.credential, DefaultAzureCredential):
# NOTE: we are not currently inspecting the chained credentials, so this could result in a loss of information
return TokenProviderConfig(provider_kind="DefaultAzureCredential", scopes=self.scopes)
else:
raise ValueError("Only DefaultAzureCredential is supported")
@classmethod
def _from_config(cls, config: TokenProviderConfig) -> Self:
"""Create a new instance of the component from a configuration object.
Args:
config (T): The configuration object.
Returns:
Self: The new instance of the component.
"""
if config.provider_kind == "DefaultAzureCredential":
return cls(DefaultAzureCredential(), *config.scopes)
else:
raise ValueError("Only DefaultAzureCredential is supported")

View File

@@ -0,0 +1,46 @@
from typing import Any, Optional, TypeVar, cast
import diskcache
from agentdhal_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self
T = TypeVar("T")
class DiskCacheStoreConfig(BaseModel):
"""Configuration for DiskCacheStore"""
directory: str # Path where cache is stored
# Could add other diskcache.Cache parameters like size_limit, etc.
class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]):
"""
A typed CacheStore implementation that uses diskcache as the underlying storage.
See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage.
Args:
cache_instance: An instance of diskcache.Cache.
The user is responsible for managing the DiskCache instance's lifetime.
"""
component_config_schema = DiskCacheStoreConfig
component_provider_override = "agentdhal_extensions.cache_store.diskcache.DiskCacheStore"
def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported]
self.cache = cache_instance
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
return cast(Optional[T], self.cache.get(key, default)) # type: ignore[reportUnknownMemberType]
def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType]
def _to_config(self) -> DiskCacheStoreConfig:
# Get directory from cache instance
return DiskCacheStoreConfig(directory=self.cache.directory)
@classmethod
def _from_config(cls, config: DiskCacheStoreConfig) -> Self:
return cls(cache_instance=diskcache.Cache(config.directory)) # type: ignore[no-any-return]

View File

@@ -0,0 +1,142 @@
import json
from typing import Any, Dict, Optional, TypeVar, cast
import redis
from agentdhal_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self
T = TypeVar("T")
class RedisStoreConfig(BaseModel):
"""Configuration for RedisStore"""
host: str = "localhost"
port: int = 6379
db: int = 0
# Add other relevant redis connection parameters
username: Optional[str] = None
password: Optional[str] = None
ssl: bool = False
socket_timeout: Optional[float] = None
class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
"""
A typed CacheStore implementation that uses redis as the underlying storage.
See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage.
This implementation provides automatic serialization and deserialization for:
- Pydantic models (uses model_dump_json/model_validate_json)
- Primitive types (strings, numbers, etc.)
Args:
cache_instance: An instance of `redis.Redis`.
The user is responsible for managing the Redis instance's lifetime.
"""
component_config_schema = RedisStoreConfig
component_provider_override = "agentdhal_extensions.cache_store.redis.RedisStore"
def __init__(self, redis_instance: redis.Redis):
self.cache = redis_instance
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
"""
Retrieve a value from the Redis cache.
This method handles both primitive values and complex objects:
- Pydantic models are automatically deserialized from JSON
- Primitive values (strings, numbers, etc.) are returned as-is
- If deserialization fails, returns the raw value or default
Args:
key: The key to retrieve
default: Value to return if key doesn't exist
Returns:
The value if found and properly deserialized, otherwise the default
"""
try:
raw_value = self.cache.get(key)
if raw_value is None:
return default
if isinstance(raw_value, bytes):
try:
# First try to decode as UTF-8 string
decoded_str = raw_value.decode("utf-8")
try:
# Try to parse as JSON and return the parsed object
parsed_json = json.loads(decoded_str)
return cast(Optional[T], parsed_json)
except json.JSONDecodeError:
# If not valid JSON, return the decoded string.
return cast(Optional[T], decoded_str)
except UnicodeDecodeError:
return default
else:
# Backward compatibility for primitives
return cast(Optional[T], raw_value)
except (redis.RedisError, ConnectionError):
# Log Redis-specific errors but return default gracefully
return default
def set(self, key: str, value: T) -> None:
"""
Store a value in the Redis cache.
This method handles both primitive values and complex objects:
- Pydantic models are automatically serialized to JSON
- Primitive values (strings, numbers, etc.) are stored as-is
Args:
key: The key to store the value under
value: The value to store
"""
try:
if isinstance(value, BaseModel):
# Serialize Pydantic models to JSON
serialized_value = value.model_dump_json().encode("utf-8")
self.cache.set(key, serialized_value)
else:
# Backward compatibility for primitives
self.cache.set(key, cast(Any, value))
except (redis.RedisError, ConnectionError, UnicodeEncodeError):
# Log the error but don't re-raise to maintain robustness
pass
def _to_config(self) -> RedisStoreConfig:
# Extract connection info from redis instance
connection_pool = self.cache.connection_pool
connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs # type: ignore[reportUnknownMemberType]
username = connection_kwargs.get("username")
password = connection_kwargs.get("password")
socket_timeout = connection_kwargs.get("socket_timeout")
return RedisStoreConfig(
host=str(connection_kwargs.get("host", "localhost")),
port=int(connection_kwargs.get("port", 6379)),
db=int(connection_kwargs.get("db", 0)),
username=str(username) if username is not None else None,
password=str(password) if password is not None else None,
ssl=bool(connection_kwargs.get("ssl", False)),
socket_timeout=float(socket_timeout) if socket_timeout is not None else None,
)
@classmethod
def _from_config(cls, config: RedisStoreConfig) -> Self:
# Create new redis instance from config
redis_instance = redis.Redis(
host=config.host,
port=config.port,
db=config.db,
username=config.username,
password=config.password,
ssl=config.ssl,
socket_timeout=config.socket_timeout,
)
return cls(redis_instance=redis_instance)

View File

@@ -0,0 +1,199 @@
import inspect
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent, indent
from typing import Any, Callable, Optional, Sequence, Set, TypeVar, Union
from agentdhal_core.code_executor import Alias, CodeResult, FunctionWithRequirements, FunctionWithRequirementsStr, Import
from typing_extensions import ParamSpec
@dataclass
class CommandLineCodeResult(CodeResult):
"""A code result class for command line code executor."""
code_file: Optional[str]
T = TypeVar("T")
P = ParamSpec("P")
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
if isinstance(func, FunctionWithRequirementsStr):
return func.func
code = inspect.getsource(func)
# Strip the decorator
if code.startswith("@"):
code = code[code.index("\n") + 1 :]
return code
def _import_to_str(im: Import) -> str:
if isinstance(im, str):
return f"import {im}"
elif isinstance(im, Alias):
return f"import {im.name} as {im.alias}"
else:
def to_str(i: Union[str, Alias]) -> str:
if isinstance(i, str):
return i
else:
return f"{i.name} as {i.alias}"
imports = ", ".join(map(to_str, im.imports))
return f"from {im.module} import {imports}"
def build_python_functions_file(
funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
""":meta private:"""
# First collect all global imports
global_imports: Set[Import] = set()
for func in funcs:
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
global_imports.update(func.global_imports)
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
for func in funcs:
content += _to_code(func) + "\n\n"
return content
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
"""Generate a stub for a function as a string
Args:
func (Callable[..., Any]): The function to generate a stub for
Returns:
str: The stub for the function
"""
if isinstance(func, FunctionWithRequirementsStr):
return to_stub(func.compiled_func)
content = f"def {func.__name__}{inspect.signature(func)}:\n"
docstring = func.__doc__
if docstring:
docstring = dedent(docstring)
docstring = '"""' + docstring + '"""'
docstring = indent(docstring, " ")
content += docstring + "\n"
content += " ..."
return content
# Raises ValueError if the file is not in the workspace
def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
first_line = code.split("\n")[0]
# TODO - support other languages
if first_line.startswith("# filename:"):
filename = first_line.split(":")[1].strip()
# Handle relative paths in the filename
path = Path(filename)
if not path.is_absolute():
path = workspace_path / path
path = path.resolve()
# Throws an error if the file is not in the workspace
relative = path.relative_to(workspace_path.resolve())
return str(relative)
return None
def silence_pip(code: str, lang: str) -> str:
"""Apply -qqq flag to pip install commands."""
if lang == "python":
regex = r"^! ?pip install"
elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
regex = r"^pip install"
else:
return code
# Find lines that start with pip install and make sure "-qqq" flag is added.
lines = code.split("\n")
for i, line in enumerate(lines):
# use regex to find lines that start with pip install.
match = re.search(regex, line)
if match is not None:
if "-qqq" not in line:
lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
return "\n".join(lines)
def get_required_packages(code: str, lang: str) -> set[str]:
ret: set[str] = set()
if lang == "python":
regex = r"^! ?pip install(.*)$"
else:
return ret
# Find lines that start with pip install and make sure "-qqq" flag is added.
lines = code.split("\n")
for _, line in enumerate(lines):
# use regex to find lines that start with pip install.
match = re.search(regex, line)
if match is not None:
reqs = match.group(1).split(",")
ret = {req.strip(" ") for req in reqs}
return ret
PYTHON_VARIANTS = ["python", "Python", "py"]
def lang_to_cmd(lang: str) -> str:
if lang in PYTHON_VARIANTS:
return "python"
if lang.startswith("python") or lang in ["bash", "sh"]:
return lang
if lang in ["shell"]:
return "sh"
if lang in ["pwsh", "powershell", "ps1"]:
# Check if pwsh is available, otherwise fall back to powershell
if shutil.which("pwsh") is not None:
return "pwsh"
elif shutil.which("powershell") is not None:
return "powershell"
else:
raise ValueError("Powershell or pwsh is not installed. Please install one of them.")
else:
raise ValueError(f"Unsupported language: {lang}")
# Regular expression for finding a code block
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
# The [ \t]* matches the potential spaces before language name.
# The (\w+)? matches the language, where the ? indicates it is optional.
# The [ \t]* matches the potential spaces (not newlines) after language name.
# The \r?\n makes sure there is a linebreak after ```.
# The (.*?) matches the code itself (non-greedy).
# The \r?\n makes sure there is a linebreak before ```.
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
def infer_lang(code: str) -> str:
"""infer the language for the code.
TODO: make it robust.
"""
if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
return "sh"
# check if code is a valid python code
try:
compile(code, "test", "exec")
return "python"
except SyntaxError:
# not a valid python code
return "unknown"

View File

@@ -0,0 +1,3 @@
from ._azure_container_code_executor import ACADynamicSessionsCodeExecutor, TokenProvider
__all__ = ["TokenProvider", "ACADynamicSessionsCodeExecutor"]

View File

@@ -0,0 +1,522 @@
# Credit to original authors
from __future__ import annotations
import asyncio
import os
import tempfile
import warnings
from pathlib import Path
from string import Template
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Protocol, Sequence, Union
from uuid import uuid4
import aiohttp
# async functions shouldn't use open()
from anyio import open_file
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import (
CodeBlock,
CodeExecutor,
CodeResult,
FunctionWithRequirements,
FunctionWithRequirementsStr,
)
from typing_extensions import ParamSpec
from .._common import build_python_functions_file, get_required_packages, to_stub
if TYPE_CHECKING:
from azure.core.credentials import AccessToken
PYTHON_VARIANTS = ["python", "Python", "py"]
__all__ = ("ACADynamicSessionsCodeExecutor", "TokenProvider")
A = ParamSpec("A")
class TokenProvider(Protocol):
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken: ...
class ACADynamicSessionsCodeExecutor(CodeExecutor):
"""(Experimental) A code executor class that executes code through a an Azure
Container Apps Dynamic Sessions instance.
.. note::
This class requires the :code:`azure` extra for the :code:`autogen-ext` package:
.. code-block:: bash
pip install "agentdhal-ext[azure]"
.. caution::
**This will execute LLM generated code on an Azure dynamic code container.**
The execution environment is similar to that of a jupyter notebook which allows for incremental code execution. The parameter functions are executed in order once at the beginning of each session. Each code block is then executed serially and in the order they are received. Each environment has a statically defined set of available packages which cannot be changed.
Currently, attempting to use packages beyond what is available on the environment will result in an error. To get the list of supported packages, call the `get_available_packages` function.
Currently the only supported language is Python.
For Python code, use the language "python" for the code block.
Args:
pool_management_endpoint (str): The azure container apps dynamic sessions endpoint.
credential (TokenProvider): An object that implements the get_token function.
timeout (int): The timeout for the execution of any single code block. Default is 60.
work_dir (str): The working directory for the code execution. If None,
a default working directory will be used. The default working
directory is a temporal directory.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
suppress_result_output bool: By default the executor will attach any result info in the execution response to the result outpu. Set this to True to prevent this.
session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart`
.. note::
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
"""
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
"python",
]
FUNCTION_PROMPT_TEMPLATE: ClassVar[str] = """You have access to the following user defined functions.
$functions"""
_AZURE_API_VER = "2024-02-02-preview"
def __init__(
self,
pool_management_endpoint: str,
credential: TokenProvider,
timeout: int = 60,
work_dir: Union[Path, str, None] = None,
functions: Sequence[
Union[
FunctionWithRequirements[Any, A],
Callable[..., Any],
FunctionWithRequirementsStr,
]
] = [],
functions_module: str = "functions",
suppress_result_output: bool = False,
session_id: Optional[str] = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
self._work_dir: Optional[Path] = None
self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
# If a user specifies a working directory, use that
if work_dir is not None:
if isinstance(work_dir, str):
self._work_dir = Path(work_dir)
else:
self._work_dir = work_dir
# Create the directory if it doesn't exist
self._work_dir.mkdir(exist_ok=True, parents=True)
# If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
else:
self._temp_dir = tempfile.TemporaryDirectory()
temp_dir_path = Path(self._temp_dir.name)
temp_dir_path.mkdir(exist_ok=True, parents=True)
self._started = False
# Rest of initialization remains the same
self._functions_module = functions_module
self._timeout = timeout
self._functions = functions
self._func_code: Optional[str] = None
# Setup could take some time so we intentionally wait for the first code block to do it.
if len(functions) > 0:
self._setup_functions_complete = False
else:
self._setup_functions_complete = True
self._suppress_result_output = suppress_result_output
self._pool_management_endpoint = pool_management_endpoint
self._access_token: str | None = None
self._session_id: str = session_id or str(uuid4())
self._available_packages: set[str] | None = None
self._credential: TokenProvider = credential
# cwd needs to be set to /mnt/data to properly read uploaded files and download written files
self._setup_cwd_complete = False
# TODO: expiration?
def _ensure_access_token(self) -> None:
if not self._access_token:
scope = "https://dynamicsessions.io/.default"
self._access_token = self._credential.get_token(scope).token
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
The template includes one variable:
- `$functions`: The functions formatted as stubs with two newlines between each function.
Args:
prompt_template (str): The prompt template. Default is the class default.
Returns:
str: The formatted prompt.
"""
template = Template(prompt_template)
return template.substitute(
functions="\n\n".join([to_stub(func) for func in self._functions]),
)
@property
def functions_module(self) -> str:
"""(Experimental) The module name for the functions."""
return self._functions_module
@property
def functions(self) -> List[str]:
raise NotImplementedError
@property
def timeout(self) -> int:
"""(Experimental) The timeout for code execution."""
return self._timeout
@property
def work_dir(self) -> Path:
# If a user specifies a working directory, use that
if self._work_dir is not None:
# If a user specifies the current directory, warn them that this is deprecated
if self._work_dir == Path("."):
warnings.warn(
"Using the current directory as work_dir is deprecated",
DeprecationWarning,
stacklevel=2,
)
return self._work_dir
# If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
elif self._temp_dir is not None:
return Path(self._temp_dir.name)
else:
raise RuntimeError("Working directory not properly initialized")
def _construct_url(self, path: str) -> str:
endpoint = self._pool_management_endpoint
if not endpoint.endswith("/"):
endpoint += "/"
url = endpoint + f"{path}?api-version={self._AZURE_API_VER}&identifier={self._session_id}"
return url
async def get_available_packages(self, cancellation_token: CancellationToken) -> set[str]:
if self._available_packages is not None:
return self._available_packages
avail_pkgs = """
import pkg_resources\n[d.project_name for d in pkg_resources.working_set]
"""
ret = await self._execute_code_dont_check_setup(
[CodeBlock(code=avail_pkgs, language="python")], cancellation_token
)
if ret.exit_code != 0:
raise ValueError(f"Failed to get list of available packages: {ret.output.strip()}")
pkgs = ret.output.strip("[]")
pkglist = pkgs.split(",\n")
return {pkg.strip(" '") for pkg in pkglist}
async def _populate_available_packages(self, cancellation_token: CancellationToken) -> None:
self._available_packages = await self.get_available_packages(cancellation_token)
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
if not self._func_code:
self._func_code = build_python_functions_file(self._functions)
# Check required function imports and packages
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
# Should we also be checking the imports?
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
required_packages = set(flattened_packages)
if self._available_packages is None:
await self._populate_available_packages(cancellation_token)
if self._available_packages is not None:
missing_pkgs = set(required_packages - self._available_packages)
if len(missing_pkgs) > 0:
raise ValueError(f"Packages unavailable in environment: {missing_pkgs}")
func_file = self.work_dir / f"{self._functions_module}.py"
func_file.write_text(self._func_code)
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = await self._execute_code_dont_check_setup(
[CodeBlock(code=self._func_code, language="python")], cancellation_token
)
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output.strip()}")
self._setup_functions_complete = True
async def _setup_cwd(self, cancellation_token: CancellationToken) -> None:
# Change the cwd to /mnt/data to properly have access to uploaded files
exec_result = await self._execute_code_dont_check_setup(
[CodeBlock(code="import os; os.chdir('/mnt/data')", language="python")], cancellation_token
)
if exec_result.exit_code != 0:
raise ValueError("Failed to set up Azure container working directory")
self._setup_cwd_complete = True
async def get_file_list(self, cancellation_token: CancellationToken) -> List[str]:
self._ensure_access_token()
timeout = aiohttp.ClientTimeout(total=float(self._timeout))
headers = {
"Authorization": f"Bearer {self._access_token}",
}
url = self._construct_url("files")
async with aiohttp.ClientSession(timeout=timeout) as client:
task = asyncio.create_task(
client.get(
url,
headers=headers,
)
)
cancellation_token.link_future(task)
try:
resp = await task
resp.raise_for_status()
data = await resp.json()
except asyncio.TimeoutError as e:
# e.add_note is only in py 3.11+
raise asyncio.TimeoutError("Timeout getting file list") from e
except asyncio.CancelledError as e:
# e.add_note is only in py 3.11+
raise asyncio.CancelledError("File list retrieval cancelled") from e
except aiohttp.ClientResponseError as e:
raise ConnectionError("Error while getting file list") from e
values = data["value"]
file_info_list: List[str] = []
for value in values:
file = value["properties"]
file_info_list.append(file["filename"])
return file_info_list
async def upload_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> None:
self._ensure_access_token()
# TODO: Better to use the client auth system rather than headers
headers = {"Authorization": f"Bearer {self._access_token}"}
url = self._construct_url("files/upload")
timeout = aiohttp.ClientTimeout(total=float(self._timeout))
async with aiohttp.ClientSession(timeout=timeout) as client:
for file in files:
file_path = self.work_dir / file
if not file_path.is_file():
# TODO: what to do here?
raise FileNotFoundError(f"{file} does not exist")
data = aiohttp.FormData()
async with await open_file(file_path, "rb") as f:
data.add_field(
"file",
f,
filename=os.path.basename(file_path),
content_type="application/octet-stream",
)
task = asyncio.create_task(
client.post(
url,
headers=headers,
data=data,
)
)
cancellation_token.link_future(task)
try:
resp = await task
resp.raise_for_status()
except asyncio.TimeoutError as e:
# e.add_note is only in py 3.11+
raise asyncio.TimeoutError("Timeout uploading files") from e
except asyncio.CancelledError as e:
# e.add_note is only in py 3.11+
raise asyncio.CancelledError("Uploading files cancelled") from e
except aiohttp.ClientResponseError as e:
raise ConnectionError("Error while uploading files") from e
async def download_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> List[str]:
self._ensure_access_token()
available_files = await self.get_file_list(cancellation_token)
# TODO: Better to use the client auth system rather than headers
headers = {"Authorization": f"Bearer {self._access_token}"}
timeout = aiohttp.ClientTimeout(total=float(self._timeout))
local_paths: List[str] = []
async with aiohttp.ClientSession(timeout=timeout) as client:
for file in files:
if file not in available_files:
# TODO: what's the right thing to do here?
raise FileNotFoundError(f"{file} does not exist")
url = self._construct_url(f"files/content/{file}")
task = asyncio.create_task(
client.get(
url,
headers=headers,
)
)
cancellation_token.link_future(task)
try:
resp = await task
resp.raise_for_status()
local_path = self.work_dir / file
local_paths.append(str(local_path))
async with await open_file(local_path, "wb") as f:
await f.write(await resp.read())
except asyncio.TimeoutError as e:
# e.add_note is only in py 3.11+
raise asyncio.TimeoutError("Timeout downloading files") from e
except asyncio.CancelledError as e:
# e.add_note is only in py 3.11+
raise asyncio.CancelledError("Downloading files cancelled") from e
except aiohttp.ClientResponseError as e:
raise ConnectionError("Error while downloading files") from e
return local_paths
async def execute_code_blocks(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CodeResult:
"""(Experimental) Execute the code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
cancellation_token (CancellationToken): a token to cancel the operation
input_files (Optional[Union[Path, str]]): Any files the code blocks will need to access
Returns:
CodeResult: The result of the code execution."""
self._ensure_access_token()
if self._available_packages is None:
await self._populate_available_packages(cancellation_token)
if not self._setup_functions_complete:
await self._setup_functions(cancellation_token)
if not self._setup_cwd_complete:
await self._setup_cwd(cancellation_token)
return await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
# The http call here should be replaced by an actual Azure client call once its available
async def _execute_code_dont_check_setup(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CodeResult:
logs_all = ""
exitcode = 0
# TODO: Better to use the client auth system rather than headers
assert self._access_token is not None
headers = {
"Authorization": f"Bearer {self._access_token}",
"Content-Type": "application/json",
}
properties = {
"codeInputType": "inline",
"executionType": "synchronous",
"code": "", # Filled in later
}
url = self._construct_url("code/execute")
timeout = aiohttp.ClientTimeout(total=float(self._timeout))
async with aiohttp.ClientSession(timeout=timeout) as client:
for code_block in code_blocks:
lang, code = code_block.language, code_block.code
lang = lang.lower()
if lang in PYTHON_VARIANTS:
lang = "python"
if lang not in self.SUPPORTED_LANGUAGES:
# In case the language is not supported, we return an error message.
exitcode = 1
logs_all += "\n" + f"unknown language {lang}"
break
if self._available_packages is not None:
req_pkgs = get_required_packages(code, lang)
missing_pkgs = set(req_pkgs - self._available_packages)
if len(missing_pkgs) > 0:
# In case the code requires packages that are not available in the environment
exitcode = 1
logs_all += "\n" + f"Python packages unavailable in environment: {missing_pkgs}"
break
properties["code"] = code_block.code
task = asyncio.create_task(
client.post(
url,
headers=headers,
json={"properties": properties},
)
)
cancellation_token.link_future(task)
try:
response = await task
response.raise_for_status()
data = await response.json()
data = data["properties"]
logs_all += data.get("stderr", "") + data.get("stdout", "")
if "Success" in data["status"]:
if not self._suppress_result_output:
logs_all += str(data["result"])
elif "Failure" in data["status"]:
exitcode = 1
except asyncio.TimeoutError as e:
logs_all += "\n Timeout"
# e.add_note is only in py 3.11+
raise asyncio.TimeoutError(logs_all) from e
except asyncio.CancelledError as e:
logs_all += "\n Cancelled"
# e.add_note is only in py 3.11+
raise asyncio.CancelledError(logs_all) from e
except aiohttp.ClientResponseError as e:
logs_all += "\nError while sending code block to endpoint"
raise ConnectionError(logs_all) from e
return CodeResult(exit_code=exitcode, output=logs_all)
async def restart(self) -> None:
"""(Experimental) Restart the code executor.
Resets the internal state of the executor by generating a new session ID and resetting the setup variables.
This causes the next code execution to reinitialize the environment and re-run any setup code.
"""
self._session_id = str(uuid4())
self._setup_functions_complete = False
self._access_token = None
self._available_packages = None
self._setup_cwd_complete = False
async def start(self) -> None:
"""(Experimental) Start the code executor.
Marks the code executor as started."""
# No setup needed for this executor
self._started = True
async def stop(self) -> None:
"""(Experimental) Stop the code executor.
Stops the code executor after cleaning up the temporary working directory (if it was created)."""
if self._temp_dir is not None:
self._temp_dir.cleanup()
self._temp_dir = None
self._started = False

View File

@@ -0,0 +1,3 @@
from ._docker_code_executor import DockerCommandLineCodeExecutor
__all__ = ["DockerCommandLineCodeExecutor"]

View File

@@ -0,0 +1,613 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/docker_commandline_code_executor.py
# Credit to original authors
from __future__ import annotations
import asyncio
import logging
import shlex
import sys
import tempfile
import uuid
import warnings
from collections.abc import Sequence
from concurrent.futures import Future as ConcurrentFuture
from hashlib import sha256
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Tuple, Union
from agentdhal_core import CancellationToken, Component
from agentdhal_core.code_executor import (
CodeBlock,
CodeExecutor,
FunctionWithRequirements,
FunctionWithRequirementsStr,
)
from pydantic import BaseModel
from typing_extensions import Self
from docker.types import DeviceRequest
from .._common import (
CommandLineCodeResult,
build_python_functions_file,
get_file_name_from_content,
lang_to_cmd,
silence_pip,
)
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
try:
import asyncio_atexit
import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container
except ImportError as e:
raise RuntimeError(
"Missing dependecies for DockerCommandLineCodeExecutor. Please ensure the autogen-ext package was installed with the 'docker' extra."
) from e
async def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
elapsed_time = 0.0
while container.status != "running" and elapsed_time < timeout:
await asyncio.sleep(stop_time)
elapsed_time += stop_time
await asyncio.to_thread(container.reload)
continue
if container.status != "running":
raise ValueError("Container failed to start")
A = ParamSpec("A")
class DockerCommandLineCodeExecutorConfig(BaseModel):
"""Configuration for DockerCommandLineCodeExecutor"""
image: str = "python:3-slim"
container_name: Optional[str] = None
timeout: int = 60
work_dir: Optional[str] = None
bind_dir: Optional[str] = None
auto_remove: bool = True
stop_container: bool = True
functions_module: str = "functions"
extra_volumes: Dict[str, Dict[str, str]] = {}
extra_hosts: Dict[str, str] = {}
init_command: Optional[str] = None
delete_tmp_files: bool = False
class DockerCommandLineCodeExecutor(CodeExecutor, Component[DockerCommandLineCodeExecutorConfig]):
"""Executes code through a command line environment in a Docker container.
.. note::
This class requires the :code:`docker` extra for the :code:`autogen-ext` package:
.. code-block:: bash
pip install "agentdhal-ext[docker]"
The executor first saves each code block in a file in the working
directory, and then executes the code file in the container.
The executor executes the code blocks in the order they are received.
Currently, the executor only supports Python and shell scripts.
For Python code, use the language "python" for the code block.
For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code block.
Args:
image (_type_, optional): Docker image to use for code execution.
Defaults to "python:3-slim".
container_name (Optional[str], optional): Name of the Docker container
which is created. If None, will autogenerate a name. Defaults to None.
timeout (int, optional): The timeout for code execution. Defaults to 60.
work_dir (Union[Path, str], optional): The working directory for the code
execution. Defaults to temporary directory.
bind_dir (Union[Path, str], optional): The directory that will be bound
to the code executor container. Useful for cases where you want to spawn
the container from within a container. Defaults to work_dir.
auto_remove (bool, optional): If true, will automatically remove the Docker
container when it is stopped. Defaults to True.
stop_container (bool, optional): If true, will automatically stop the
container when stop is called, when the context manager exits or when
the Python process exits with atext. Defaults to True.
device_requests (Optional[List[DeviceRequest]], optional): A list of device request instances to add to the container for exposing GPUs (e.g., [docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])]). Defaults to None.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions".
extra_volumes (Optional[Dict[str, Dict[str, str]]], optional): A dictionary of extra volumes (beyond the work_dir) to mount to the container;
key is host source path and value 'bind' is the container path. See Defaults to None.
Example: extra_volumes = {'/home/user1/': {'bind': '/mnt/vol2', 'mode': 'rw'}, '/var/www': {'bind': '/mnt/vol1', 'mode': 'ro'}}
extra_hosts (Optional[Dict[str, str]], optional): A dictionary of host mappings to add to the container. (See Docker docs on extra_hosts) Defaults to None.
Example: extra_hosts = {"kubernetes.docker.internal": "host-gateway"}
init_command (Optional[str], optional): A shell command to run before each shell operation execution. Defaults to None.
Example: init_command="kubectl config use-context docker-hub"
delete_tmp_files (bool, optional): If true, will delete temporary files after execution. Defaults to False.
.. note::
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
"""
component_config_schema = DockerCommandLineCodeExecutorConfig
component_provider_override = "agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor"
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
"bash",
"shell",
"sh",
"pwsh",
"powershell",
"ps1",
"python",
]
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
$functions"""
def __init__(
self,
image: str = "python:3-slim",
container_name: Optional[str] = None,
*,
timeout: int = 60,
work_dir: Union[Path, str, None] = None,
bind_dir: Optional[Union[Path, str]] = None,
auto_remove: bool = True,
stop_container: bool = True,
device_requests: Optional[List[DeviceRequest]] = None,
functions: Sequence[
Union[
FunctionWithRequirements[Any, A],
Callable[..., Any],
FunctionWithRequirementsStr,
]
] = [],
functions_module: str = "functions",
extra_volumes: Optional[Dict[str, Dict[str, str]]] = None,
extra_hosts: Optional[Dict[str, str]] = None,
init_command: Optional[str] = None,
delete_tmp_files: bool = False,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
# Handle working directory logic
if work_dir is None:
self._work_dir = None
else:
if isinstance(work_dir, str):
work_dir = Path(work_dir)
# Emit a deprecation warning if the user is using the current directory as working directory
if work_dir.resolve() == Path.cwd().resolve():
warnings.warn(
"Using the current directory as work_dir is deprecated.",
DeprecationWarning,
stacklevel=2,
)
self._work_dir = work_dir
# Create the working directory if it doesn't exist
self._work_dir.mkdir(exist_ok=True, parents=True)
if container_name is None:
self.container_name = f"agentdhal-code-exec-{uuid.uuid4()}"
else:
self.container_name = container_name
self._timeout = timeout
# Handle bind_dir
self._bind_dir: Optional[Path] = None
if bind_dir is not None:
self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir
else:
self._bind_dir = self._work_dir # Default to work_dir if not provided
# Track temporary directory
self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
self._temp_dir_path: Optional[Path] = None
self._started = False
self._auto_remove = auto_remove
self._stop_container = stop_container
self._image = image
if not functions_module.isidentifier():
raise ValueError("Module name must be a valid Python identifier")
self._functions_module = functions_module
self._functions = functions
self._extra_volumes = extra_volumes if extra_volumes is not None else {}
self._extra_hosts = extra_hosts if extra_hosts is not None else {}
self._init_command = init_command
self._delete_tmp_files = delete_tmp_files
self._device_requests = device_requests
# Setup could take some time so we intentionally wait for the first code block to do it.
if len(functions) > 0:
self._setup_functions_complete = False
else:
self._setup_functions_complete = True
self._container: Container | None = None
self._running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._cancellation_futures: List[ConcurrentFuture[None]] = []
@property
def timeout(self) -> int:
"""(Experimental) The timeout for code execution."""
return self._timeout
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
func_file_content = build_python_functions_file(self._functions)
func_file = self.work_dir / f"{self._functions_module}.py"
func_file.write_text(func_file_content)
# Collect requirements
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
packages = shlex.join(required_packages)
result = await self._execute_code_dont_check_setup(
[CodeBlock(code=f"python -m pip install {packages}", language="sh")], cancellation_token
)
if result.exit_code != 0:
stdout = result.output
stderr = result.output
raise ValueError(f"Pip install failed. {stdout}, {stderr}")
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = await self._execute_code_dont_check_setup(
[CodeBlock(code=func_file_content, language="python")], cancellation_token
)
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
self._setup_functions_complete = True
async def _kill_running_command(self, command: List[str]) -> None:
if self._container is None or not self._running:
return
await asyncio.to_thread(self._container.exec_run, ["pkill", "-f", " ".join(command)])
async def _execute_command(self, command: List[str], cancellation_token: CancellationToken) -> Tuple[str, int]:
if self._container is None or not self._running:
raise ValueError("Container is not running. Must first be started with either start or a context manager.")
exec_task = asyncio.create_task(asyncio.to_thread(self._container.exec_run, command))
cancellation_token.link_future(exec_task)
# Wait for the exec task to finish.
try:
result = await exec_task
exit_code = result.exit_code
output = result.output.decode("utf-8")
if exit_code == 124:
output += "\n Timeout"
return output, exit_code
except asyncio.CancelledError:
# Schedule a task to kill the running command in the background.
if self._loop and not self._loop.is_closed():
try:
logging.debug(f"Scheduling kill command via run_coroutine_threadsafe on loop {self._loop!r}")
future: ConcurrentFuture[None] = asyncio.run_coroutine_threadsafe(
self._kill_running_command(command), self._loop
)
self._cancellation_futures.append(future)
logging.debug(f"Kill command scheduled, future: {future!r}")
except RuntimeError as e:
logging.error(f"Failed to schedule kill command on loop {self._loop!r}: {e}")
except Exception as e:
logging.exception(f"Unexpected error scheduling kill command: {e}")
else:
logging.warning(
f"Cannot schedule kill command: Executor loop is not available or closed (loop: {self._loop!r})."
)
return "Code execution was cancelled.", 1
async def _execute_code_dont_check_setup(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
if self._container is None or not self._running:
raise ValueError("Container is not running. Must first be started with either start or a context manager.")
if len(code_blocks) == 0:
raise ValueError("No code blocks to execute.")
outputs: List[str] = []
files: List[Path] = []
last_exit_code = 0
try:
for code_block in code_blocks:
lang = code_block.language.lower()
code = silence_pip(code_block.code, lang)
# Check if there is a filename comment
try:
filename = get_file_name_from_content(code, self.work_dir)
except ValueError:
outputs.append("Filename is not in the workspace")
last_exit_code = 1
break
if not filename:
filename = f"tmp_code_{sha256(code.encode()).hexdigest()}.{lang}"
code_path = self.work_dir / filename
with code_path.open("w", encoding="utf-8") as fout:
fout.write(code)
files.append(code_path)
command = ["timeout", str(self._timeout), lang_to_cmd(lang), filename]
output, exit_code = await self._execute_command(command, cancellation_token)
outputs.append(output)
last_exit_code = exit_code
if exit_code != 0:
break
finally:
if self._delete_tmp_files:
for file in files:
try:
file.unlink()
except (OSError, FileNotFoundError):
pass
code_file = str(files[0]) if files else None
return CommandLineCodeResult(exit_code=last_exit_code, output="".join(outputs), code_file=code_file)
@property
def work_dir(self) -> Path:
# If a user specifies a working directory, use that
if self._work_dir is not None:
# If a user specifies the current directory, warn them that this is deprecated
if self._work_dir == Path("."):
warnings.warn(
"Using the current directory as work_dir is deprecated.",
DeprecationWarning,
stacklevel=2,
)
return self._work_dir
# If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
elif self._temp_dir is not None:
return Path(self._temp_dir.name)
else:
raise RuntimeError("Working directory not properly initialized")
@property
def bind_dir(self) -> Path:
# If the user specified a bind directory, return it
if self._bind_dir is not None:
return self._bind_dir
# Otherwise bind_dir is set to the current work_dir as default
else:
return self.work_dir
async def execute_code_blocks(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
"""(Experimental) Execute the code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CommandlineCodeResult: The result of the code execution."""
if not self._setup_functions_complete:
await self._setup_functions(cancellation_token)
return await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
async def restart(self) -> None:
"""(Experimental) Restart the Docker container code executor."""
if self._container is None or not self._running:
raise ValueError("Container is not running. Must first be started with either start or a context manager.")
await asyncio.to_thread(self._container.restart) # type: ignore
if self._container.status != "running":
self._running = False
logs_str = self._container.logs().decode("utf-8")
raise ValueError(f"Failed to restart container. Logs: {logs_str}")
async def stop(self) -> None:
"""(Experimental) Stop the code executor.
Stops the Docker container and cleans up any temporary files (if they were created), along with the temporary directory.
The method first waits for all cancellation tasks to finish before stopping the container. Finally it marks the executor as not running.
If the container is not running, the method does nothing.
"""
if not self._running:
return
if self._temp_dir is not None:
self._temp_dir.cleanup()
self._temp_dir = None
client = docker.from_env()
try:
try:
container = await asyncio.to_thread(client.containers.get, self.container_name)
except NotFound:
logging.debug(f"Container {self.container_name} not found during stop...")
self._running = False
self._cancellation_futures.clear()
return
if self._cancellation_futures:
if not self._loop or self._loop.is_closed():
logging.warning(
f"Executor loop ({self._loop!r}) is closed or unavailable. Cannot reliably wait for "
f"{len(self._cancellation_futures)} cancellation futures."
)
self._cancellation_futures.clear()
else:
# concurrent.futures.Future -> asyncio.Future
asyncio_futures = [asyncio.wrap_future(f, loop=self._loop) for f in self._cancellation_futures]
if asyncio_futures:
logging.debug(
f"Waiting for {len(asyncio_futures)} cancellation futures to complete on loop {self._loop!r}..."
)
results = await asyncio.gather(*asyncio_futures, return_exceptions=True)
for i, result in enumerate(results):
original_future = self._cancellation_futures[i]
if isinstance(result, Exception):
logging.warning(f"Cancellation future {original_future!r} failed: {result}")
else:
logging.debug(f"Cancellation future {original_future!r} completed successfully.")
else:
logging.debug("No valid cancellation futures to await.")
self._cancellation_futures.clear()
logging.debug(f"Stopping container {self.container_name}...")
await asyncio.to_thread(container.stop)
logging.debug(f"Container {self.container_name} stopped.")
except DockerException as e:
logging.error(f"Docker error while stopping container {self.container_name}: {e}")
except Exception as e:
logging.exception(f"Unexpected error during stop operation for container {self.container_name}: {e}")
finally:
self._running = False
self._cancellation_futures.clear()
async def start(self) -> None:
"""(Experimental) Start the code executor.
This method sets the working environment variables, connects to Docker and starts the code executor.
If no working directory was provided to the code executor, it creates a temporary directory and sets it as the code executor working directory.
"""
if self._work_dir is None and self._temp_dir is None:
self._temp_dir = tempfile.TemporaryDirectory()
self._temp_dir_path = Path(self._temp_dir.name)
self._temp_dir_path.mkdir(exist_ok=True)
# Start a container from the image, read to exec commands later
try:
client = docker.from_env()
except DockerException as e:
if "FileNotFoundError" in str(e):
raise RuntimeError("Failed to connect to Docker. Please ensure Docker is installed and running.") from e
raise
except Exception as e:
raise RuntimeError(f"Unexpected error while connecting to Docker: {str(e)}") from e
# Check if the image exists
try:
await asyncio.to_thread(client.images.get, self._image)
except ImageNotFound:
# TODO logger
logging.info(f"Pulling image {self._image}...")
# Let the docker exception escape if this fails.
await asyncio.to_thread(client.images.pull, self._image)
# Prepare the command (if needed)
shell_command = "/bin/sh"
command = ["-c", f"{(self._init_command)};exec {shell_command}"] if self._init_command else None
# Check if a container with the same name already exists and remove it
try:
existing_container = await asyncio.to_thread(client.containers.get, self.container_name)
await asyncio.to_thread(existing_container.remove, force=True)
except NotFound:
pass
self._container = await asyncio.to_thread(
client.containers.create,
self._image,
name=self.container_name,
entrypoint=shell_command,
command=command,
tty=True,
detach=True,
auto_remove=self._auto_remove,
volumes={str(self.bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}, **self._extra_volumes},
working_dir="/workspace",
extra_hosts=self._extra_hosts,
device_requests=self._device_requests,
)
await asyncio.to_thread(self._container.start)
await _wait_for_ready(self._container)
async def cleanup() -> None:
await self.stop()
asyncio_atexit.unregister(cleanup) # type: ignore
if self._stop_container:
asyncio_atexit.register(cleanup) # type: ignore
# Check if the container is running
if self._container.status != "running":
logs_str = self._container.logs().decode("utf-8")
raise ValueError(f"Failed to start container from image {self._image}. Logs: {logs_str}")
self._loop = asyncio.get_running_loop()
self._cancellation_futures = []
logging.debug(f"Executor started, associated with event loop: {self._loop!r}")
self._running = True
def _to_config(self) -> DockerCommandLineCodeExecutorConfig:
"""(Experimental) Convert the component to a config object."""
if self._functions:
logging.info("Functions will not be included in serialized configuration")
return DockerCommandLineCodeExecutorConfig(
image=self._image,
container_name=self.container_name,
timeout=self._timeout,
work_dir=str(self._work_dir) if self._work_dir else None,
bind_dir=str(self._bind_dir) if self._bind_dir else None,
auto_remove=self._auto_remove,
stop_container=self._stop_container,
functions_module=self._functions_module,
extra_volumes=self._extra_volumes,
extra_hosts=self._extra_hosts,
init_command=self._init_command,
delete_tmp_files=self._delete_tmp_files,
)
@classmethod
def _from_config(cls, config: DockerCommandLineCodeExecutorConfig) -> Self:
"""(Experimental) Create a component from a config object."""
return cls(
image=config.image,
container_name=config.container_name,
timeout=config.timeout,
work_dir=Path(config.work_dir) if config.work_dir else None,
bind_dir=Path(config.bind_dir) if config.bind_dir else None,
auto_remove=config.auto_remove,
stop_container=config.stop_container,
functions=[], # Functions not restored from config
functions_module=config.functions_module,
extra_volumes=config.extra_volumes,
extra_hosts=config.extra_hosts,
init_command=config.init_command,
delete_tmp_files=config.delete_tmp_files,
)

View File

@@ -0,0 +1,10 @@
from ._docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterCodeResult
from ._jupyter_server import DockerJupyterServer, JupyterClient, JupyterKernelClient
__all__ = [
"DockerJupyterCodeExecutor",
"DockerJupyterServer",
"JupyterClient",
"JupyterKernelClient",
"DockerJupyterCodeResult",
]

View File

@@ -0,0 +1,300 @@
import asyncio
import base64
import json
import os
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import List, Optional, Union
from agentdhal_core import CancellationToken, Component
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
from agentdhal_extensions.code_executors._common import silence_pip
from pydantic import BaseModel
from typing_extensions import Self
from ._jupyter_server import JupyterClient, JupyterConnectable, JupyterConnectionInfo, JupyterKernelClient
@dataclass
class DockerJupyterCodeResult(CodeResult):
"""(Experimental) A code result class for IPython code executor."""
output_files: list[Path]
class DockerJupyterCodeExecutorConfig(BaseModel):
"""Configuration for JupyterCodeExecutor"""
jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo]
kernel_name: str = "python3"
timeout: int = 60
output_dir: Optional[Union[Path, str]] = None
class Config:
arbitrary_types_allowed = True
class DockerJupyterCodeExecutor(CodeExecutor, Component[DockerJupyterCodeExecutorConfig]):
"""(Experimental) A code executor class that executes code statefully using
a Jupyter server supplied to this class.
Each execution is stateful and can access variables created from previous
executions in the same session.
To use this, you need to install the following dependencies:
.. code-block:: shell
pip install "agentdhal-ext[docker-jupyter-executor]"
Args:
jupyter_server (Union[JupyterConnectable, JupyterConnectionInfo]): The Jupyter server to use.
kernel_name (str): The kernel name to use. Make sure it is installed.
By default, it is "python3".
timeout (int): The timeout for code execution, by default 60.
output_dir (str): The directory to save output files, by default None.
Example of using it directly:
.. code-block:: python
import asyncio
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeBlock
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
async def main() -> None:
async with DockerJupyterServer() as jupyter_server:
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken())
print(code_result)
asyncio.run(main())
Example of using it with your own jupyter image:
.. code-block:: python
import asyncio
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeBlock
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
async def main() -> None:
async with DockerJupyterServer(custom_image_name="your_custom_images_name", expose_port=8888) as jupyter_server:
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken())
print(code_result)
asyncio.run(main())
Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
async def main() -> None:
async with DockerJupyterServer() as jupyter_server:
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
tool = PythonCodeExecutionTool(executor)
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent("assistant", model_client=model_client, tools=[tool])
result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.")
print(result)
asyncio.run(main())
Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import CodeExecutorAgent
from agentdhal_agentchat.messages import TextMessage
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
from agentdhal_core import CancellationToken
async def main() -> None:
async with DockerJupyterServer() as jupyter_server:
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor)
task = TextMessage(
content='''Here is some code
```python
print('Hello world')
```
''',
source="user",
)
response = await code_executor_agent.on_messages([task], CancellationToken())
print(response.chat_message)
asyncio.run(main())
"""
component_config_schema = DockerJupyterCodeExecutorConfig
component_provider_override = "agentdhal_extensions.code_executors.docker_jupyter.DockerJupyterCodeExecutor"
def __init__(
self,
jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo],
kernel_name: str = "python3",
timeout: int = 60,
output_dir: Path | None = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
if isinstance(jupyter_server, JupyterConnectable):
self._connection_info = jupyter_server.connection_info
elif isinstance(jupyter_server, JupyterConnectionInfo):
self._connection_info = jupyter_server
else:
raise ValueError("jupyter_server must be a JupyterConnectable or JupyterConnectionInfo.")
self._output_dir = output_dir or getattr(jupyter_server, "_bind_dir", None)
if not self._output_dir:
with tempfile.TemporaryDirectory() as temp_dir:
self._output_dir = Path(temp_dir)
self._output_dir.mkdir(exist_ok=True)
self._jupyter_client = JupyterClient(self._connection_info)
self._kernel_name = kernel_name
self._timeout = timeout
self._async_jupyter_kernel_client: Optional[JupyterKernelClient] = None
self._kernel_id: Optional[str] = None
async def _ensure_async_kernel_client(self) -> JupyterKernelClient:
"""Ensure that an async kernel client exists and return it."""
if self._kernel_id is None:
await self.start()
assert self._kernel_id is not None
if self._async_jupyter_kernel_client is None:
self._async_jupyter_kernel_client = await self._jupyter_client.get_kernel_client(self._kernel_id)
return self._async_jupyter_kernel_client
async def execute_code_blocks(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> DockerJupyterCodeResult:
"""(Experimental) Execute a list of code blocks and return the result.
This method executes a list of code blocks as cells in the Jupyter kernel.
See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
for the message protocol.
Args:
code_blocks (List[CodeBlock]): A list of code blocks to execute.
Returns:
DockerJupyterCodeResult: The result of the code execution.
"""
kernel_client = await self._ensure_async_kernel_client()
# Wait for kernel to be ready using async client
is_ready = await kernel_client.wait_for_ready(timeout_seconds=self._timeout)
if not is_ready:
return DockerJupyterCodeResult(exit_code=1, output="ERROR: Kernel not ready", output_files=[])
outputs: List[str] = []
output_files: List[Path] = []
for code_block in code_blocks:
code = silence_pip(code_block.code, code_block.language)
# Execute code using async client
exec_task = asyncio.create_task(kernel_client.execute(code, timeout_seconds=self._timeout))
cancellation_token.link_future(exec_task)
result = await exec_task
if result.is_ok:
outputs.append(result.output)
for data in result.data_items:
if data.mime_type == "image/png":
path = self._save_image(data.data)
outputs.append(path)
output_files.append(Path(path))
elif data.mime_type == "text/html":
path = self._save_html(data.data)
outputs.append(path)
output_files.append(Path(path))
else:
outputs.append(json.dumps(data.data))
else:
existing_output = "\n".join([str(output) for output in outputs])
return DockerJupyterCodeResult(
exit_code=1, output=existing_output + "\nERROR: " + result.output, output_files=output_files
)
return DockerJupyterCodeResult(
exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
)
async def restart(self) -> None:
"""(Experimental) Restart a new session."""
# Use async client to restart kernel
if self._kernel_id is not None:
await self._jupyter_client.restart_kernel(self._kernel_id)
# Reset the clients to force recreation
if self._async_jupyter_kernel_client is not None:
await self._async_jupyter_kernel_client.stop()
self._async_jupyter_kernel_client = None
async def start(self) -> None:
"""(Experimental) Start a new session."""
available_kernels = await self._jupyter_client.list_kernel_specs()
if self._kernel_name not in available_kernels["kernelspecs"]:
raise ValueError(f"Kernel {self._kernel_name} is not installed.")
self._kernel_id = await self._jupyter_client.start_kernel(self._kernel_name)
def _save_image(self, image_data_base64: str) -> str:
"""Save image data to a file."""
image_data = base64.b64decode(image_data_base64)
filename = f"{uuid.uuid4().hex}.png"
path = os.path.join(str(self._output_dir), filename)
with open(path, "wb") as f:
f.write(image_data)
return os.path.abspath(path)
def _save_html(self, html_data: str) -> str:
"""Save html data to a file."""
filename = f"{uuid.uuid4().hex}.html"
path = os.path.join(str(self._output_dir), filename)
with open(path, "w") as f:
f.write(html_data)
return os.path.abspath(path)
async def stop(self) -> None:
"""Stop the kernel."""
if self._kernel_id is not None:
await self._jupyter_client.delete_kernel(self._kernel_id)
if self._async_jupyter_kernel_client is not None:
await self._async_jupyter_kernel_client.stop()
self._async_jupyter_kernel_client = None
await self._jupyter_client.close()
async def __aenter__(self) -> Self:
await self.start()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.stop()

View File

@@ -0,0 +1,430 @@
import asyncio
import atexit
import datetime
import io
import json
import logging
import os
import secrets
import uuid
from dataclasses import dataclass
from pathlib import Path
from time import sleep
from types import TracebackType
from typing import Any, Dict, List, Optional, Protocol, Type, Union, cast, runtime_checkable
import aiohttp
import docker
import docker.errors
import requests
import websockets
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import Self
@dataclass
class JupyterConnectionInfo:
"""(Experimental)"""
host: str
"""`str` - Host of the Jupyter gateway server"""
use_https: bool
"""`bool` - Whether to use HTTPS"""
port: Optional[int] = None
"""`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
token: Optional[str] = None
"""`Optional[str]` - Token for authentication. If None, no token is used"""
@runtime_checkable
class JupyterConnectable(Protocol):
"""(Experimental)"""
@property
def connection_info(self) -> JupyterConnectionInfo:
"""Return the connection information for this connectable."""
...
class JupyterClient:
def __init__(self, connection_info: JupyterConnectionInfo):
"""(Experimental) A client for communicating with a Jupyter gateway server.
Args:
connection_info (JupyterConnectionInfo): Connection information
"""
self._connection_info = connection_info
self._session = requests.Session()
retries = Retry(total=5, backoff_factor=0.1)
self._session.mount("http://", HTTPAdapter(max_retries=retries))
# Create aiohttp session for async requests
self._async_session: aiohttp.ClientSession | None = None
async def _ensure_async_session(self) -> aiohttp.ClientSession:
if self._async_session is None:
self._async_session = aiohttp.ClientSession()
return self._async_session
def _get_headers(self) -> Dict[str, str]:
if self._connection_info.token is None:
return {}
return {"Authorization": f"token {self._connection_info.token}"}
def _get_api_base_url(self) -> str:
protocol = "https" if self._connection_info.use_https else "http"
port = f":{self._connection_info.port}" if self._connection_info.port else ""
return f"{protocol}://{self._connection_info.host}{port}"
def _get_ws_base_url(self) -> str:
port = f":{self._connection_info.port}" if self._connection_info.port else ""
return f"ws://{self._connection_info.host}{port}"
async def list_kernel_specs(self) -> Dict[str, Dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
return cast(Dict[str, Dict[str, str]], response.json())
async def list_kernels(self) -> List[Dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
return cast(List[Dict[str, str]], response.json())
async def start_kernel(self, kernel_spec_name: str) -> str:
"""Start a new kernel asynchronously.
Args:
kernel_spec_name (str): Name of the kernel spec to start
Returns:
str: ID of the started kernel
"""
session = await self._ensure_async_session()
async with session.post(
f"{self._get_api_base_url()}/api/kernels",
headers=self._get_headers(),
json={"name": kernel_spec_name},
) as response:
data = await response.json()
return cast(str, data["id"])
async def delete_kernel(self, kernel_id: str) -> None:
session = await self._ensure_async_session()
async with session.delete(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
) as response:
response.raise_for_status()
async def restart_kernel(self, kernel_id: str) -> None:
session = await self._ensure_async_session()
async with session.post(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
) as response:
response.raise_for_status()
async def get_kernel_client(self, kernel_id: str) -> "JupyterKernelClient":
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
# Using websockets library for async websocket connections
ws = await websockets.connect(ws_url, additional_headers=self._get_headers())
return JupyterKernelClient(ws)
async def close(self) -> None:
"""Close the async session"""
if self._async_session is not None:
await self._async_session.close()
self._async_session = None
self._session.close()
@dataclass
class DataItem:
mime_type: str
data: str
@dataclass
class ExecutionResult:
is_ok: bool
output: str
data_items: List[DataItem]
class JupyterKernelClient:
"""An asynchronous client for communicating with a Jupyter kernel."""
def __init__(self, websocket: websockets.ClientConnection) -> None:
self._session_id = uuid.uuid4().hex
self._websocket = websocket
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
await self.stop()
async def stop(self) -> None:
await self._websocket.close()
async def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str:
timestamp = datetime.datetime.now().isoformat()
message_id = uuid.uuid4().hex
message = {
"header": {
"username": "agentdhal",
"version": "5.0",
"session": self._session_id,
"msg_id": message_id,
"msg_type": message_type,
"date": timestamp,
},
"parent_header": {},
"channel": channel,
"content": content,
"metadata": {},
"buffers": {},
}
await self._websocket.send(json.dumps(message))
return message_id
async def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]:
try:
if timeout_seconds is not None:
data = await asyncio.wait_for(self._websocket.recv(), timeout=timeout_seconds)
else:
data = await self._websocket.recv()
if isinstance(data, bytes):
return cast(Dict[str, Any], json.loads(data.decode("utf-8")))
return cast(Dict[str, Any], json.loads(data))
except asyncio.TimeoutError:
return None
async def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")
while True:
message = await self._receive_message(timeout_seconds)
# This means we timed out with no new messages.
if message is None:
return False
if (
message.get("parent_header", {}).get("msg_id") == message_id
and message["msg_type"] == "kernel_info_reply"
):
return True
async def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
message_id = await self._send_message(
content={
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
channel="shell",
message_type="execute_request",
)
text_output: List[str] = []
data_output: List[DataItem] = []
while True:
message = await self._receive_message(timeout_seconds)
if message is None:
return ExecutionResult(
is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
)
# Ignore messages that are not for this execution.
if message.get("parent_header", {}).get("msg_id") != message_id:
continue
msg_type = message["msg_type"]
content = message["content"]
if msg_type in ["execute_result", "display_data"]:
for data_type, data in content["data"].items():
if data_type == "text/plain":
text_output.append(data)
elif data_type.startswith("image/") or data_type == "text/html":
data_output.append(DataItem(mime_type=data_type, data=data))
else:
text_output.append(json.dumps(data))
elif msg_type == "stream":
text_output.append(content["text"])
elif msg_type == "error":
# Output is an error.
return ExecutionResult(
is_ok=False,
output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
data_items=[],
)
if msg_type == "status" and content["execution_state"] == "idle":
break
return ExecutionResult(
is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output
)
class DockerJupyterServer(JupyterConnectable):
DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
USER ${NB_UID}
RUN mamba install --yes jupyter_kernel_gateway ipykernel && \
mamba clean --all -f -y && \
fix-permissions "${CONDA_DIR}" && \
fix-permissions "/home/${NB_USER}"
ENV TOKEN="UNSET"
CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \
--KernelGatewayApp.port=8888 \
--KernelGatewayApp.auth_token="${TOKEN}" \
--JupyterApp.answer_yes=true \
--JupyterWebsocketPersonality.list_kernels=true
EXPOSE 8888
WORKDIR "${HOME}"
"""
class GenerateToken:
pass
def __init__(
self,
*,
custom_image_name: Optional[str] = None,
container_name: Optional[str] = None,
auto_remove: bool = True,
stop_container: bool = True,
docker_env: Optional[Dict[str, str]] = None,
expose_port: int = 8888,
token: Optional[Union[str, GenerateToken]] = None,
work_dir: Union[Path, str] = "/workspace",
bind_dir: Optional[Union[Path, str]] = None,
):
"""Start a Jupyter kernel gateway server in a Docker container.
Args:
custom_image_name: Custom Docker image to use. If None, builds and uses bundled image.
container_name: Name for the Docker container. Auto-generated if None.
auto_remove: If True, container will be deleted when stopped.
stop_container: If True, container stops on program exit or when context manager exits.
docker_env: Additional environment variables for the container.
expose_port: Port to expose for Jupyter connection.
token: Authentication token. If GenerateToken, creates random token. Empty for no auth.
work_dir: Working directory inside the container.
bind_dir: Local directory to bind to container's work_dir.
"""
# Generate container name if not provided
container_name = container_name or f"agentdhal-jupyterkernelgateway-{uuid.uuid4()}"
# Initialize Docker client
client = docker.from_env()
# Set up bind directory if specified
self._bind_dir: Optional[Path] = None
if bind_dir:
self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir
self._bind_dir.mkdir(exist_ok=True)
os.chmod(bind_dir, 0o777)
# Determine and prepare Docker image
image_name = custom_image_name or "agentdhal-jupyterkernelgateway"
if not custom_image_name:
try:
client.images.get(image_name)
except docker.errors.ImageNotFound:
# Build default image if not found
here = Path(__file__).parent
dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8"))
logging.info(f"Building image {image_name}...")
client.images.build(path=str(here), fileobj=dockerfile, tag=image_name)
logging.info(f"Image {image_name} built successfully")
else:
# Verify custom image exists
try:
client.images.get(image_name)
except docker.errors.ImageNotFound as err:
raise ValueError(f"Custom image {image_name} does not exist") from err
if docker_env is None:
docker_env = {}
if token is None:
token = DockerJupyterServer.GenerateToken()
# Set up authentication token
self._token = secrets.token_hex(32) if isinstance(token, DockerJupyterServer.GenerateToken) else token
# Prepare environment variables
env = {"TOKEN": self._token}
env.update(docker_env)
# Define volume configuration if bind directory is specified
volumes = {str(self._bind_dir): {"bind": str(work_dir), "mode": "rw"}} if self._bind_dir else None
# Start the container
container = client.containers.run(
image_name,
detach=True,
auto_remove=auto_remove,
environment=env,
publish_all_ports=True,
name=container_name,
volumes=volumes,
working_dir=str(work_dir),
)
# Wait for container to be ready
self._wait_for_ready(container)
# Store container information
self._container = container
self._port = int(container.ports[f"{expose_port}/tcp"][0]["HostPort"])
self._container_id = container.id
self._expose_port = expose_port
if self._container_id is None:
raise ValueError("Failed to obtain container id.")
# Define cleanup function
def cleanup() -> None:
try:
assert self._container_id is not None
inner_container = client.containers.get(self._container_id)
inner_container.stop()
except docker.errors.NotFound:
pass
atexit.unregister(cleanup)
# Register cleanup if container should be stopped automatically
if stop_container:
atexit.register(cleanup)
self._cleanup_func = cleanup
self._stop_container = stop_container
@property
def connection_info(self) -> JupyterConnectionInfo:
return JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=self._port, token=self._token)
def _wait_for_ready(self, container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
elapsed_time = 0.0
while container.status != "running" and elapsed_time < timeout:
sleep(stop_time)
elapsed_time += stop_time
container.reload()
continue
if container.status != "running":
raise ValueError("Container failed to start")
async def stop(self) -> None:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._cleanup_func)
async def get_client(self) -> JupyterClient:
return JupyterClient(self.connection_info)
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
await self.stop()

View File

@@ -0,0 +1,6 @@
from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult
__all__ = [
"JupyterCodeExecutor",
"JupyterCodeResult",
]

View File

@@ -0,0 +1,335 @@
import asyncio
import base64
import json
import re
import sys
import tempfile
import uuid
import warnings
from dataclasses import dataclass
from pathlib import Path
from agentdhal_core import Component
from pydantic import BaseModel
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from contextlib import AbstractAsyncContextManager
from typing import Optional, Union
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
from nbclient import NotebookClient
from nbformat import NotebookNode
from nbformat import v4 as nbformat
from typing_extensions import Self
from .._common import silence_pip
@dataclass
class JupyterCodeResult(CodeResult):
"""A code result class for Jupyter code executor."""
output_files: list[Path]
class JupyterCodeExecutorConfig(BaseModel):
"""Configuration for JupyterCodeExecutor"""
kernel_name: str = "python3"
timeout: int = 60
output_dir: Optional[str] = None
class JupyterCodeExecutor(CodeExecutor, Component[JupyterCodeExecutorConfig]):
"""A code executor class that executes code statefully using [nbclient](https://github.com/jupyter/nbclient).
.. danger::
This will execute code on the local machine. If being used with LLM generated code, caution should be used.
Example of using it directly:
.. code-block:: python
import asyncio
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeBlock
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
async def main() -> None:
async with JupyterCodeExecutor() as executor:
cancel_token = CancellationToken()
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
code_result = await executor.execute_code_blocks(code_blocks, cancel_token)
print(code_result)
asyncio.run(main())
Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
async def main() -> None:
async with JupyterCodeExecutor() as executor:
tool = PythonCodeExecutionTool(executor)
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent("assistant", model_client=model_client, tools=[tool])
result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.")
print(result)
asyncio.run(main())
Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import CodeExecutorAgent
from agentdhal_agentchat.messages import TextMessage
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
from agentdhal_core import CancellationToken
async def main() -> None:
async with JupyterCodeExecutor() as executor:
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor)
task = TextMessage(
content='''Here is some code
```python
print('Hello world')
```
''',
source="user",
)
response = await code_executor_agent.on_messages([task], CancellationToken())
print(response.chat_message)
asyncio.run(main())
Args:
kernel_name (str): The kernel name to use. By default, "python3".
timeout (int): The timeout for code execution, by default 60.
output_dir (Path): The directory to save output files, by default a temporary directory.
.. note::
Using the current directory (".") as output directory is deprecated. Using it will raise a deprecation warning.
"""
component_config_schema = JupyterCodeExecutorConfig
component_provider_override = "agentdhal_extensions.code_executors.jupyter.JupyterCodeExecutor"
def __init__(
self,
kernel_name: str = "python3",
timeout: int = 60,
output_dir: Optional[Union[Path, str]] = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
self._output_dir: Path = Path(tempfile.mkdtemp()) if output_dir is None else Path(output_dir)
self._output_dir.mkdir(exist_ok=True, parents=True)
self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
self._temp_dir_path: Optional[Path] = None
self._started = False
self._kernel_name = kernel_name
self._timeout = timeout
self._client: Optional[NotebookClient] = None
self.kernel_context: Optional[AbstractAsyncContextManager[None]] = None
async def execute_code_blocks(
self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken
) -> JupyterCodeResult:
"""Execute code blocks and return the result.
Args:
code_blocks (list[CodeBlock]): The code blocks to execute.
Returns:
JupyterCodeResult: The result of the code execution.
"""
outputs: list[str] = []
output_files: list[Path] = []
exit_code = 0
for code_block in code_blocks:
result = await self._execute_code_block(code_block, cancellation_token)
exit_code = result.exit_code
outputs.append(result.output)
output_files.extend(result.output_files)
# Stop execution if one code block fails
if exit_code != 0:
break
return JupyterCodeResult(exit_code=exit_code, output="\n".join(outputs), output_files=output_files)
async def _execute_code_block(
self, code_block: CodeBlock, cancellation_token: CancellationToken
) -> JupyterCodeResult:
"""Execute single code block and return the result.
Args:
code_block (CodeBlock): The code block to execute.
Returns:
JupyterCodeResult: The result of the code execution.
"""
execute_task = asyncio.create_task(
self._execute_cell(
nbformat.new_code_cell(silence_pip(code_block.code, code_block.language)) # type: ignore
)
)
cancellation_token.link_future(execute_task)
output_cell = await asyncio.wait_for(asyncio.shield(execute_task), timeout=self._timeout)
outputs: list[str] = []
output_files: list[Path] = []
exit_code = 0
for output in output_cell.get("outputs", []):
match output.get("output_type"):
case "stream":
outputs.append(output.get("text", ""))
case "error":
traceback = re.sub(r"\x1b\[[0-9;]*[A-Za-z]", "", "\n".join(output["traceback"]))
outputs.append(traceback)
exit_code = 1
case "execute_result" | "display_data":
data = output.get("data", {})
for mime, content in data.items():
match mime:
case "text/plain":
outputs.append(content)
case "image/png":
path = self._save_image(content)
output_files.append(path)
case "image/jpeg":
# TODO: Should this also be encoded? Images are encoded as both png and jpg
pass
case "text/html":
path = self._save_html(content)
output_files.append(path)
case _:
outputs.append(json.dumps(content))
case _:
pass
return JupyterCodeResult(exit_code=exit_code, output="\n".join(outputs), output_files=output_files)
async def _execute_cell(self, cell: NotebookNode) -> NotebookNode:
# Temporary push cell to nb as async_execute_cell expects it. But then we want to remove it again as cells can take up significant amount of memory (especially with images)
if not self._client:
raise RuntimeError("Executor must be started before executing cells")
self._client.nb.cells.append(cell)
output = await self._client.async_execute_cell(
cell,
cell_index=0,
)
self._client.nb.cells.pop()
return output
def _save_image(self, image_data_base64: str) -> Path:
"""Save image data to a file."""
image_data = base64.b64decode(image_data_base64)
path = self._output_dir / f"{uuid.uuid4().hex}.png"
path.write_bytes(image_data)
return path.absolute()
def _save_html(self, html_data: str) -> Path:
"""Save HTML data to a file."""
path = self._output_dir / f"{uuid.uuid4().hex}.html"
path.write_text(html_data)
return path.absolute()
async def restart(self) -> None:
"""Restart the code executor."""
await self.stop()
await self.start()
async def start(self) -> None:
"""(Experimental) Start the code executor.
Initializes the Jupyter Notebook execution environment by creating a new notebook and setting it up with the specified Jupyter Kernel.
Marks the executor as started, allowing for code execution.
This method should be called before executing any code blocks.
"""
if self._started:
return
notebook: NotebookNode = nbformat.new_notebook() # type: ignore
self._client = NotebookClient(
nb=notebook,
kernel_name=self._kernel_name,
timeout=self._timeout,
allow_errors=True,
)
self.kernel_context = self._client.async_setup_kernel()
await self.kernel_context.__aenter__()
self._started = True
async def stop(self) -> None:
"""(Experimental) Stop the code executor.
Terminates the Jupyter Notebook execution by exiting the kernel context and cleaning up the associated resources."""
if not self._started:
return
if self.kernel_context is not None:
await self.kernel_context.__aexit__(None, None, None)
self.kernel_context = None
self._client = None
self._started = False
def _to_config(self) -> JupyterCodeExecutorConfig:
"""Convert current instance to config object"""
return JupyterCodeExecutorConfig(
kernel_name=self._kernel_name, timeout=self._timeout, output_dir=str(self.output_dir)
)
@property
def output_dir(self) -> Path:
# If a user specifies the current directory, warn them that this is deprecated
if self._output_dir == Path("."):
warnings.warn(
"Using the current directory as output_dir is deprecated",
DeprecationWarning,
stacklevel=2,
)
return self._output_dir
@classmethod
def _from_config(cls, config: JupyterCodeExecutorConfig) -> Self:
"""Create instance from config object"""
return cls(
kernel_name=config.kernel_name,
timeout=config.timeout,
output_dir=Path(config.output_dir) if config.output_dir else None,
)

View File

@@ -0,0 +1,517 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py
# Credit to original authors
import asyncio
import logging
import os
import sys
import tempfile
import warnings
from hashlib import sha256
from pathlib import Path
from string import Template
from types import SimpleNamespace
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
from agentdhal_core import CancellationToken, Component
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, FunctionWithRequirements, FunctionWithRequirementsStr
from pydantic import BaseModel
from typing_extensions import ParamSpec, Self
from .._common import (
PYTHON_VARIANTS,
CommandLineCodeResult,
build_python_functions_file,
get_file_name_from_content,
lang_to_cmd,
silence_pip,
to_stub,
)
__all__ = ("LocalCommandLineCodeExecutor",)
A = ParamSpec("A")
class LocalCommandLineCodeExecutorConfig(BaseModel):
"""Configuration for LocalCommandLineCodeExecutor"""
timeout: int = 60
work_dir: Optional[str] = None
functions_module: str = "functions"
cleanup_temp_files: bool = True
class LocalCommandLineCodeExecutor(CodeExecutor, Component[LocalCommandLineCodeExecutorConfig]):
"""A code executor class that executes code through a local command line
environment.
.. danger::
This will execute code on the local machine. If being used with LLM generated code, caution should be used.
Each code block is saved as a file and executed in a separate process in
the working directory, and a unique file is generated and saved in the
working directory for each code block.
The code blocks are executed in the order they are received.
Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
commands from being executed which may potentially affect the users environment.
Currently the only supported languages is Python and shell scripts.
For Python code, use the language "python" for the code block.
For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code
block.
.. note::
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
.. code-block:: python
import sys
import asyncio
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
Args:
timeout (int): The timeout for the execution of any single code block. Default is 60.
work_dir (str): The working directory for the code execution. If None,
a default working directory will be used. The default working directory is a temporary directory.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions".
cleanup_temp_files (bool, optional): Whether to automatically clean up temporary files after execution. Defaults to True.
virtual_env_context (Optional[SimpleNamespace], optional): The virtual environment context. Defaults to None.
.. note::
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
Example:
How to use `LocalCommandLineCodeExecutor` with a virtual environment different from the one used to run the autogen application:
Set up a virtual environment using the `venv` module, and pass its context to the initializer of `LocalCommandLineCodeExecutor`. This way, the executor will run code within the new environment.
.. code-block:: python
import venv
from pathlib import Path
import asyncio
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeBlock
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
async def example():
work_dir = Path("coding")
work_dir.mkdir(exist_ok=True)
venv_dir = work_dir / ".venv"
venv_builder = venv.EnvBuilder(with_pip=True)
venv_builder.create(venv_dir)
venv_context = venv_builder.ensure_directories(venv_dir)
local_executor = LocalCommandLineCodeExecutor(work_dir=work_dir, virtual_env_context=venv_context)
await local_executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="bash", code="pip install matplotlib"),
],
cancellation_token=CancellationToken(),
)
asyncio.run(example())
"""
component_config_schema = LocalCommandLineCodeExecutorConfig
component_provider_override = "agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor"
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
"bash",
"shell",
"sh",
"pwsh",
"powershell",
"ps1",
"python",
]
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
$functions"""
def __init__(
self,
timeout: int = 60,
work_dir: Optional[Union[Path, str]] = None,
functions: Sequence[
Union[
FunctionWithRequirements[Any, A],
Callable[..., Any],
FunctionWithRequirementsStr,
]
] = [],
functions_module: str = "functions",
cleanup_temp_files: bool = True,
virtual_env_context: Optional[SimpleNamespace] = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
self._timeout = timeout
self._work_dir: Optional[Path] = None
if work_dir is not None:
# Check if user provided work_dir is the current directory and warn if so.
if Path(work_dir).resolve() == Path.cwd().resolve():
warnings.warn(
"Using the current directory as work_dir is deprecated.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(work_dir, str):
self._work_dir = Path(work_dir)
else:
self._work_dir = work_dir
self._work_dir.mkdir(exist_ok=True)
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
if len(functions) > 0:
self._setup_functions_complete = False
else:
self._setup_functions_complete = True
if not functions_module.isidentifier():
raise ValueError("Module name must be a valid Python identifier")
self._functions_module = functions_module
self._cleanup_temp_files = cleanup_temp_files
self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context
self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
self._started = False
# Check the current event loop policy if on windows.
if sys.platform == "win32":
current_policy = asyncio.get_event_loop_policy()
if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance(
current_policy, asyncio.WindowsProactorEventLoopPolicy
):
warnings.warn(
"The current event loop policy is not WindowsProactorEventLoopPolicy. "
"This may cause issues with subprocesses. "
"Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
"For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
"See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.",
stacklevel=2,
)
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
The template includes two variables:
- `$module_name`: The module name.
- `$functions`: The functions formatted as stubs with two newlines between each function.
Args:
prompt_template (str): The prompt template. Default is the class default.
Returns:
str: The formatted prompt.
"""
template = Template(prompt_template)
return template.substitute(
module_name=self._functions_module,
functions="\n\n".join([to_stub(func) for func in self._functions]),
)
@property
def timeout(self) -> int:
"""(Experimental) The timeout for code execution."""
return self._timeout
@property
def work_dir(self) -> Path:
"""(Experimental) The working directory for the code execution."""
if self._work_dir is not None:
return self._work_dir
else:
# Automatically create temp directory if not exists
if self._temp_dir is None:
self._temp_dir = tempfile.TemporaryDirectory()
self._started = True
return Path(self._temp_dir.name)
@property
def functions(self) -> List[str]:
raise NotImplementedError
@property
def functions_module(self) -> str:
"""(Experimental) The module name for the functions."""
return self._functions_module
@property
def cleanup_temp_files(self) -> bool:
"""(Experimental) Whether to automatically clean up temporary files after execution."""
return self._cleanup_temp_files
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
func_file_content = build_python_functions_file(self._functions)
func_file = self.work_dir / f"{self._functions_module}.py"
func_file.write_text(func_file_content)
# Collect requirements
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
cmd_args = ["-m", "pip", "install"]
cmd_args.extend(required_packages)
if self._virtual_env_context:
py_executable = self._virtual_env_context.env_exe
else:
py_executable = sys.executable
task = asyncio.create_task(
asyncio.create_subprocess_exec(
py_executable,
*cmd_args,
cwd=self.work_dir,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
)
cancellation_token.link_future(task)
try:
proc = await task
stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
except asyncio.TimeoutError as e:
raise ValueError("Pip install timed out") from e
except asyncio.CancelledError as e:
raise ValueError("Pip install was cancelled") from e
if proc.returncode is not None and proc.returncode != 0:
raise ValueError(f"Pip install failed. {stdout.decode()}, {stderr.decode()}")
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = await self._execute_code_dont_check_setup(
[CodeBlock(code=func_file_content, language="python")], cancellation_token
)
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
self._setup_functions_complete = True
async def execute_code_blocks(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
"""(Experimental) Execute the code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
cancellation_token (CancellationToken): a token to cancel the operation
Returns:
CommandLineCodeResult: The result of the code execution."""
if not self._setup_functions_complete:
await self._setup_functions(cancellation_token)
return await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
async def _execute_code_dont_check_setup(
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
"""
Execute the provided code blocks in the local command line without re-checking setup.
Returns a CommandLineCodeResult indicating success or failure.
"""
logs_all: str = ""
file_names: List[Path] = []
exitcode = 0
for code_block in code_blocks:
lang, code = code_block.language, code_block.code
lang = lang.lower()
# Remove pip output where possible
code = silence_pip(code, lang)
# Normalize python variants to "python"
if lang in PYTHON_VARIANTS:
lang = "python"
# Abort if not supported
if lang not in self.SUPPORTED_LANGUAGES:
exitcode = 1
logs_all += "\n" + f"unknown language {lang}"
break
# Try extracting a filename (if present)
try:
filename = get_file_name_from_content(code, self.work_dir)
except ValueError:
return CommandLineCodeResult(
exit_code=1,
output="Filename is not in the workspace",
code_file=None,
)
# If no filename is found, create one
if filename is None:
code_hash = sha256(code.encode()).hexdigest()
if lang.startswith("python"):
ext = "py"
elif lang in ["pwsh", "powershell", "ps1"]:
ext = "ps1"
else:
ext = lang
filename = f"tmp_code_{code_hash}.{ext}"
written_file = (self.work_dir / filename).resolve()
with written_file.open("w", encoding="utf-8") as f:
f.write(code)
file_names.append(written_file)
# Build environment
env = os.environ.copy()
if self._virtual_env_context:
virtual_env_bin_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
env["PATH"] = f"{virtual_env_bin_abs_path}{os.pathsep}{env['PATH']}"
# Decide how to invoke the script
if lang == "python":
program = (
os.path.abspath(self._virtual_env_context.env_exe) if self._virtual_env_context else sys.executable
)
extra_args = [str(written_file.absolute())]
else:
# Get the appropriate command for the language
program = lang_to_cmd(lang)
# Special handling for PowerShell
if program == "pwsh":
extra_args = [
"-NoProfile",
"-ExecutionPolicy",
"Bypass",
"-File",
str(written_file.absolute()),
]
else:
# Shell commands (bash, sh, etc.)
extra_args = [str(written_file.absolute())]
# Create a subprocess and run
task = asyncio.create_task(
asyncio.create_subprocess_exec(
program,
*extra_args,
cwd=self.work_dir,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
)
cancellation_token.link_future(task)
proc = None # Track the process
try:
proc = await task
stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
exitcode = proc.returncode or 0
except asyncio.TimeoutError:
logs_all += "\nTimeout"
exitcode = 124
if proc:
proc.terminate()
await proc.wait() # Ensure process is fully dead
break
except asyncio.CancelledError:
logs_all += "\nCancelled"
exitcode = 125
if proc:
proc.terminate()
await proc.wait()
break
logs_all += stderr.decode()
logs_all += stdout.decode()
if exitcode != 0:
break
code_file = str(file_names[0]) if file_names else None
code_result = CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)
if self._cleanup_temp_files:
for file in file_names:
try:
file.unlink(missing_ok=True)
except OSError as error:
logging.error(f"Failed to delete temporary file {file}: {error}")
return code_result
async def restart(self) -> None:
"""(Experimental) Restart the code executor."""
warnings.warn(
"Restarting local command line code executor is not supported. No action is taken.",
stacklevel=2,
)
async def start(self) -> None:
"""(Experimental) Start the code executor.
Initializes the local code executor and should be called before executing any code blocks.
It marks the executor internal state as started.
If no working directory is provided, the method creates a temporary directory for the executor to use.
"""
if self._work_dir is None and self._temp_dir is None:
self._temp_dir = tempfile.TemporaryDirectory()
self._started = True
async def stop(self) -> None:
"""(Experimental) Stop the code executor.
Stops the local code executor and performs the cleanup of the temporary working directory (if it was created).
The executor's internal state is markes as no longer started.
"""
if self._temp_dir is not None:
self._temp_dir.cleanup()
self._temp_dir = None
self._started = False
pass
def _to_config(self) -> LocalCommandLineCodeExecutorConfig:
if self._functions:
logging.info("Functions will not be included in serialized configuration")
if self._virtual_env_context:
logging.info("Virtual environment context will not be included in serialized configuration")
return LocalCommandLineCodeExecutorConfig(
timeout=self._timeout,
work_dir=str(self.work_dir),
functions_module=self._functions_module,
cleanup_temp_files=self._cleanup_temp_files,
)
@classmethod
def _from_config(cls, config: LocalCommandLineCodeExecutorConfig) -> Self:
return cls(
timeout=config.timeout,
work_dir=Path(config.work_dir) if config.work_dir is not None else None,
functions_module=config.functions_module,
cleanup_temp_files=config.cleanup_temp_files,
)

View File

@@ -0,0 +1,210 @@
# Task-Centric Memory
_(EXPERIMENTAL, RESEARCH IN PROGRESS)_
**Task-Centric Memory** is an active research project aimed at giving AI agents the ability to:
* Accomplish general tasks more effectively by learning quickly and continually beyond context-window limitations.
* Remember guidance, corrections, plans, and demonstrations provided by users.
* Learn through the agent's own experience and adapt quickly to changing circumstances.
* Avoid repeating mistakes on tasks that are similar to those previously encountered.
## Installation
Install AutoGen and its extension package as follows:
```bash
pip install -U "autogen-agentchat" "autogen-ext[openai]" "autogen-ext[task-centric-memory]"
```
## Quickstart
<p align="right">
<img src="../../../../imgs/task_centric_memory_2.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
</p>
This first code snippet runs a basic test to verify that the installation was successful,
as illustrated by the diagram to the right.
```python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.experimental.task_centric_memory import MemoryController
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
async def main() -> None:
client = OpenAIChatCompletionClient(model="gpt-4o")
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful.
memory_controller = MemoryController(reset=True, client=client, logger=logger)
# Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
# Retrieve memories for a new task that's related to only two of the stored memories.
memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
print("{} memories retrieved".format(len(memos)))
for memo in memos:
print("- " + memo.insight)
asyncio.run(main())
```
<p align="right">
<img src="../../../../imgs/task_centric_memory_3.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
</p>
This second code example shows one way to incorporate task-centric memory directly into an AutoGen agent,
in this case a subclass of RoutedAgent.
To keep the code short, only the simplest form of memory retrieval is exercised by this agent.
```python
import asyncio
from dataclasses import dataclass
from typing import List
from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.experimental.task_centric_memory import MemoryController
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
@dataclass
class Message:
content: str
class MemoryEnabledAgent(RoutedAgent):
def __init__(
self, description: str, model_client: ChatCompletionClient, memory_controller: MemoryController
) -> None:
super().__init__(description)
self._model_client = model_client
self._memory_controller = memory_controller
@message_handler
async def handle_message(self, message: Message, context: MessageContext) -> Message:
# Retrieve relevant memories for the task.
memos = await self._memory_controller.retrieve_relevant_memos(task=message.content)
# Format the memories for the model.
formatted_memos = "Info that may be useful:\n" + "\n".join(["- " + memo.insight for memo in memos])
print(f"{'-' * 23}Text appended to the user message{'-' * 24}\n{formatted_memos}\n{'-' * 80}")
# Create the messages for the model with the retrieved memories.
messages: List[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content=message.content, source="user"),
UserMessage(content=formatted_memos, source="user"),
]
# Call the model with the messages.
model_result = await self._model_client.create(messages=messages)
assert isinstance(model_result.content, str)
# Send the model's response to the user.
return Message(content=model_result.content)
async def main() -> None:
client = OpenAIChatCompletionClient(model="gpt-4o")
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart2"}) # Optional, but very useful.
memory_controller = MemoryController(reset=True, client=client, logger=logger)
# Prepopulate memory to mimic learning from a prior session.
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
# Create and start an agent runtime.
runtime = SingleThreadedAgentRuntime()
runtime.start()
# Register the agent type.
await MemoryEnabledAgent.register(
runtime,
"memory_enabled_agent",
lambda: MemoryEnabledAgent(
"A agent with memory", model_client=client, memory_controller=memory_controller
),
)
# Send a direct message to the agent.
request = "What colors do I like most?"
print("User request: " + request)
response = await runtime.send_message(
Message(content=request), AgentId("memory_enabled_agent", "default")
)
print("Agent response: " + response.content)
# Stop the agent runtime.
await runtime.stop()
asyncio.run(main())
```
## Sample Code
The example above modifies the agent's code.
But it's also possible to add task-centric memory to an agent or multi-agent team _without_ modifying any agent code.
See the [sample code](../../../../../../samples/task_centric_memory) for that and other forms of fast, memory-based learning.
## Architecture
<p align="right">
<img src="../../../../imgs/task_centric_memory.png" alt="Description" width="300" align="right" style="margin-left: 10px;">
</p>
The block diagram to the right outlines the key components of the architecture in the most general form.
The memory components are shown in blue, and the green blocks represent external components.
The **Memory Controller** implements the fast-learning methods described below,
and manages communication with a **Memory Bank** containing a vector DB and associated structures.
The **Agent or Team** is the AI agent or team of agents to which memory is being added.
The sample code shows how to add task-centric memory to a simple AssistantAgent or a MagenticOneGroupChat team.
The **Apprentice, app, or service** represents the code that instantiates the agent and memory controller,
and routes information between them, effectively wrapping agent and memory into a combined component.
The term _Apprentice_ connotes that this combination uses memory to learn quickly on the job.
The Apprentice class is a minimal reference implementation provided as utility code for illustration and testing,
but most applications will use their own code instead of the Apprentice.
## Memory Creation and Storage
Each stored memory (called a _memo_) contains a text insight and (optionally) a task description.
The insight is intended to help the agent accomplish future tasks that are similar to a prior task.
The memory controller provides methods for different types of learning.
If the user provides advice for solving a given task, the advice is extracted by the model client and stored as an insight.
If the user demonstrates how to perform a task,
the task and demonstration are stored together as an insight used to solve similar but different tasks.
If the agent is given a task (free of side-effects) and some means of determining success or failure,
the memory controller repeats the following learning loop in the background some number of times:
1. Test the agent on the task a few times to check for a failure.
2. If a failure is found, analyze the agent's response in order to:
1. Diagnose the failure of reasoning or missing information,
2. Phrase a general piece of advice, such as what a teacher might give to a student,
3. Temporarily append this advice to the task description,
4. Return to step 1.
5. If some piece of advice succeeds in helping the agent solve the task a number of times, add the advice as an insight to memory.
3. For each insight to be stored in memory, an LLM is prompted to generate a set of free-form, multi-word topics related to the insight. Each topic is embedded to a fixed-length vector and stored in a vector DB mapping it to the topics related insight.
## Memory Retrieval and Usage
The memory controller provides methods for different types of memory retrieval.
When the agent is given a task, the following steps are performed by the controller:
1. The task is rephrased into a generalized form.
2. A set of free-form, multi-word query topics are generated from the generalized task.
3. A potentially large number of previously stored topics, those most similar to each query topic, are retrieved from the vector DB along with the insights they map to.
4. These candidate memos are filtered by the aggregate similarity of their stored topics to the query topics.
5. In the final filtering stage, an LLM is prompted to validate only those insights that seem potentially useful in solving the task at hand.
Retrieved insights that pass the filtering steps are listed under a heading like
"Important insights that may help solve tasks like this", then appended to the task description before it is passed to the agent as usual.

View File

@@ -0,0 +1,4 @@
from ._memory_bank import MemoryBankConfig
from .memory_controller import MemoryController, MemoryControllerConfig
__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"]

View File

@@ -0,0 +1,201 @@
import os
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, TypedDict
from ._string_similarity_map import StringSimilarityMap
from .utils.page_logger import PageLogger
@dataclass
class Memo:
"""
Represents an atomic unit of memory that can be stored in a memory bank and later retrieved.
"""
task: str | None # The task description, if any.
insight: str # A hint, solution, plan, or any other text that may help solve a similar task.
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class MemoryBankConfig(TypedDict, total=False):
path: str
relevance_conversion_threshold: float
n_results: int
distance_threshold: int
class MemoryBank:
"""
Stores task-completion insights as memories in a vector DB for later retrieval.
Args:
reset: True to clear the DB before starting.
config: An optional dict that can be used to override the following values:
- path: The path to the directory where the memory bank files are stored.
- relevance_conversion_threshold: The threshold used to normalize relevance.
- n_results: The maximum number of most relevant results to return for any given topic.
- distance_threshold: The maximum string-pair distance for a memo to be retrieved.
logger: An optional logger. If None, no logging will be performed.
"""
def __init__(
self,
reset: bool,
config: MemoryBankConfig | None = None,
logger: PageLogger | None = None,
) -> None:
if logger is None:
logger = PageLogger() # Nothing will be logged by this object.
self.logger = logger
self.logger.enter_function()
# Apply default settings and any config overrides.
memory_dir_path = "./memory_bank/default"
self.relevance_conversion_threshold = 1.7
self.n_results = 25
self.distance_threshold = 100
if config is not None:
memory_dir_path = config.get("path", memory_dir_path)
self.relevance_conversion_threshold = config.get(
"relevance_conversion_threshold", self.relevance_conversion_threshold
)
self.n_results = config.get("n_results", self.n_results)
self.distance_threshold = config.get("distance_threshold", self.distance_threshold)
memory_dir_path = os.path.expanduser(memory_dir_path)
self.logger.info("\nMEMORY BANK DIRECTORY {}".format(memory_dir_path))
path_to_db_dir = os.path.join(memory_dir_path, "string_map")
self.path_to_dict = os.path.join(memory_dir_path, "uid_memo_dict.pkl")
self.string_map = StringSimilarityMap(reset=reset, path_to_db_dir=path_to_db_dir, logger=self.logger)
# Load or create the associated memo dict on disk.
self.uid_memo_dict: Dict[str, Memo] = {}
self.last_memo_id = 0
if (not reset) and os.path.exists(self.path_to_dict):
self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "rb") as f:
self.uid_memo_dict = pickle.load(f)
self.last_memo_id = len(self.uid_memo_dict)
self.logger.info("\n{} MEMOS LOADED".format(len(self.uid_memo_dict)))
# Clear the DB if requested.
if reset:
self._reset_memos()
self.logger.leave_function()
def reset(self) -> None:
"""
Forces immediate deletion of all contents, in memory and on disk.
"""
self.string_map.reset_db()
self._reset_memos()
def _reset_memos(self) -> None:
"""
Forces immediate deletion of the memos, in memory and on disk.
"""
self.logger.info("\nCLEARING MEMOS")
self.uid_memo_dict = {}
self.save_memos()
def save_memos(self) -> None:
"""
Saves the current memo structures (possibly empty) to disk.
"""
self.string_map.save_string_pairs()
with open(self.path_to_dict, "wb") as file:
self.logger.info("\nSAVING MEMOS TO DISK at {}".format(self.path_to_dict))
pickle.dump(self.uid_memo_dict, file)
def contains_memos(self) -> bool:
"""
Returns True if the memory bank contains any memo.
"""
return len(self.uid_memo_dict) > 0
def _map_topics_to_memo(self, topics: List[str], memo_id: str, memo: Memo) -> None:
"""
Adds a mapping in the vec DB from each topic to the memo.
"""
self.logger.enter_function()
self.logger.info("\nINSIGHT\n{}".format(memo.insight))
for topic in topics:
self.logger.info("\n TOPIC = {}".format(topic))
self.string_map.add_input_output_pair(topic, memo_id)
self.uid_memo_dict[memo_id] = memo
self.save_memos()
self.logger.leave_function()
def add_memo(self, insight_str: str, topics: List[str], task_str: Optional[str] = None) -> None:
"""
Adds an insight to the memory bank, given topics related to the insight, and optionally the task.
"""
self.logger.enter_function()
self.last_memo_id += 1
id_str = str(self.last_memo_id)
insight = Memo(insight=insight_str, task=task_str)
self._map_topics_to_memo(topics, id_str, insight)
self.logger.leave_function()
def add_task_with_solution(self, task: str, solution: str, topics: List[str]) -> None:
"""
Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
This is useful when the insight is a demonstration of how to solve a given type of task.
"""
self.logger.enter_function()
self.last_memo_id += 1
id_str = str(self.last_memo_id)
# Prepend the insight to the task description for context.
insight_str = "Example task:\n\n{}\n\nExample solution:\n\n{}".format(task, solution)
memo = Memo(insight=insight_str, task=task)
self._map_topics_to_memo(topics, id_str, memo)
self.logger.leave_function()
def get_relevant_memos(self, topics: List[str]) -> List[Memo]:
"""
Returns any memos from the memory bank that appear sufficiently relevant to the input topics.
"""
self.logger.enter_function()
# Retrieve all topic matches, and gather them into a single list.
matches: List[Tuple[str, str, float]] = [] # Each match is a tuple: (topic, memo_id, distance)
for topic in topics:
matches.extend(self.string_map.get_related_string_pairs(topic, self.n_results, self.distance_threshold))
# Build a dict of memo-relevance pairs from the matches.
memo_relevance_dict: Dict[str, float] = {}
for match in matches:
relevance = self.relevance_conversion_threshold - match[2]
memo_id = match[1]
if memo_id in memo_relevance_dict:
memo_relevance_dict[memo_id] += relevance
else:
memo_relevance_dict[memo_id] = relevance
# Log the details of all the retrieved memos.
self.logger.info("\n{} POTENTIALLY RELEVANT MEMOS".format(len(memo_relevance_dict)))
for memo_id, relevance in memo_relevance_dict.items():
memo = self.uid_memo_dict[memo_id]
details = ""
if memo.task is not None:
details += "\n TASK: {}\n".format(memo.task)
details += "\n INSIGHT: {}\n\n RELEVANCE: {:.3f}\n".format(memo.insight, relevance)
self.logger.info(details)
# Sort the memo-relevance pairs by relevance, in descending order.
memo_relevance_dict = dict(sorted(memo_relevance_dict.items(), key=lambda item: item[1], reverse=True))
# Compose the list of sufficiently relevant memos to return.
memo_list: List[Memo] = []
for memo_id in memo_relevance_dict:
if memo_relevance_dict[memo_id] >= 0:
memo_list.append(self.uid_memo_dict[memo_id])
self.logger.leave_function()
return memo_list

View File

@@ -0,0 +1,289 @@
import time
from typing import List, Union
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
SystemMessage,
UserMessage,
)
from .utils._functions import UserContent
from .utils.page_logger import PageLogger
class Prompter:
"""
Centralizes most of the Apprentice prompts sent to the model client.
Args:
client: The client to call the model.
logger: An optional logger. If None, no logging will be performed.
"""
def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
if logger is None:
logger = PageLogger() # Nothing will be logged by this object.
self.logger = logger
self.client = client
self.default_system_message_content = "You are a helpful assistant."
self.time_spent_in_model_calls = 0.0
self.num_model_calls = 0
self.start_time = time.time()
# Create the chat history
self._chat_history: List[LLMMessage] = []
async def call_model(
self,
summary: str,
user_content: UserContent,
system_message_content: str | None = None,
keep_these_messages: bool = True,
) -> str:
"""
Calls the model client with the given input and returns the response.
"""
# Prepare the input message list
if system_message_content is None:
system_message_content = self.default_system_message_content
system_message: LLMMessage
if self.client.model_info["family"] == "o1":
# No system message allowed, so pass it as the first user message.
system_message = UserMessage(content=system_message_content, source="User")
else:
# System message allowed.
system_message = SystemMessage(content=system_message_content)
user_message = UserMessage(content=user_content, source="User")
input_messages = [system_message] + self._chat_history + [user_message]
# Double check the types of the input messages.
for message in input_messages:
for part in message.content:
assert isinstance(part, str) or isinstance(part, Image), "Invalid message content type: {}".format(
type(part)
)
# Call the model
start_time = time.time()
response = await self.client.create(input_messages)
assert isinstance(response, CreateResult)
response_string = response.content
assert isinstance(response_string, str)
response_message = AssistantMessage(content=response_string, source="Assistant")
assert isinstance(response_message, AssistantMessage)
self.time_spent_in_model_calls += time.time() - start_time
self.num_model_calls += 1
# Log the model call
self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)
# Manage the chat history
if keep_these_messages:
self._chat_history.append(user_message)
self._chat_history.append(response_message)
# Return the response as a string for now
return response_string
def _clear_history(self) -> None:
"""
Empties the message list containing the chat history.
"""
self._chat_history = []
async def learn_from_failure(
self, task_description: str, memory_section: str, final_response: str, expected_answer: str, work_history: str
) -> str:
"""
Tries to create an insight to help avoid the given failure in the future.
"""
sys_message = """- You are a patient and thorough teacher.
- Your job is to review work done by students and help them learn how to do better."""
user_message: List[Union[str, Image]] = []
user_message.append("# A team of students made a mistake on the following task:\n")
user_message.extend([task_description])
if len(memory_section) > 0:
user_message.append(memory_section)
user_message.append("# Here's the expected answer, which would have been correct:\n")
user_message.append(expected_answer)
user_message.append("# Here is the students' answer, which was INCORRECT:\n")
user_message.append(final_response)
user_message.append("# Please review the students' work which follows:\n")
user_message.append("**----- START OF STUDENTS' WORK -----**\n\n")
user_message.append(work_history)
user_message.append("\n**----- END OF STUDENTS' WORK -----**\n\n")
user_message.append(
"# Now carefully review the students' work above, explaining in detail what the students did right and what they did wrong.\n"
)
self._clear_history()
await self.call_model(
summary="Ask the model to learn from this failure",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Now put yourself in the mind of the students. What misconception led them to their incorrect answer?"
]
await self.call_model(
summary="Ask the model to state the misconception",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Please express your key insights in the form of short, general advice that will be given to the students. Just one or two sentences, or they won't bother to read it."
]
insight = await self.call_model(
summary="Ask the model to formulate a concise insight",
system_message_content=sys_message,
user_content=user_message,
)
return insight
async def find_index_topics(self, input_string: str) -> List[str]:
"""
Returns a list of topics related to the given string.
"""
sys_message = """You are an expert at semantic analysis."""
user_message: List[Union[str, Image]] = []
user_message.append("""- My job is to create a thorough index for a book called Task Completion, and I need your help.
- Every paragraph in the book needs to be indexed by all the topics related to various kinds of tasks and strategies for completing them.
- Your job is to read the text below and extract the task-completion topics that are covered.
- The number of topics depends on the length and content of the text. But you should list at least one topic, and potentially many more.
- Each topic you list should be a meaningful phrase composed of a few words. Don't use whole sentences as topics.
- Don't include details that are unrelated to the general nature of the task, or a potential strategy for completing tasks.
- List each topic on a separate line, without any extra text like numbering, or bullets, or any other formatting, because we don't want those things in the index of the book.\n\n""")
user_message.append("# Text to be indexed\n")
user_message.append(input_string)
self._clear_history()
topics = await self.call_model(
summary="Ask the model to extract topics", system_message_content=sys_message, user_content=user_message
)
# Parse the topics into a list.
topic_list: List[str] = []
for line in topics.split("\n"):
if len(line) > 0:
topic_list.append(line)
return topic_list
async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
"""
Attempts to rewrite a task description in a more general form.
"""
sys_message = """You are a helpful and thoughtful assistant."""
user_message: List[Union[str, Image]] = [
"We have been given a task description. Our job is not to complete the task, but merely rephrase the task in simpler, more general terms, if possible. Please reach through the following task description, then explain your understanding of the task in detail, as a single, flat list of all the important points."
]
user_message.append("\n# Task description")
user_message.append(task_description)
self._clear_history()
generalized_task = await self.call_model(
summary="Ask the model to rephrase the task in a list of important points",
system_message_content=sys_message,
user_content=user_message,
)
if revise:
user_message = [
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
]
await self.call_model(
summary="Ask the model to identify irrelevant points",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
]
generalized_task = await self.call_model(
summary="Ask the model to make a final list of general terms",
system_message_content=sys_message,
user_content=user_message,
)
return generalized_task
async def validate_insight(self, insight: str, task_description: str) -> bool:
"""
Judges whether the insight could help solve the task.
"""
sys_message = """You are a helpful and thoughtful assistant."""
user_message: List[Union[str, Image]] = [
"""We have been given a potential insight that may or may not be useful for solving a given task.
- First review the following task.
- Then review the insight that follows, and consider whether it might help solve the given task.
- Do not attempt to actually solve the task.
- Reply with a single character, '1' if the insight may be useful, or '0' if it is not."""
]
user_message.append("\n# Task description")
user_message.append(task_description)
user_message.append("\n# Possibly useful insight")
user_message.append(insight)
self._clear_history()
response = await self.call_model(
summary="Ask the model to validate the insight",
system_message_content=sys_message,
user_content=user_message,
)
return response == "1"
async def extract_task(self, text: str) -> str | None:
"""
Returns a task found in the given text, or None if not found.
"""
sys_message = """You are a helpful and thoughtful assistant."""
user_message: List[Union[str, Image]] = [
"""Does the following text contain a question or a some task we are being asked to perform?
- If so, please reply with the full question or task description, along with any supporting information, but without adding extra commentary or formatting.
- If the task is just to remember something, that doesn't count as a task, so don't include it.
- If there is no question or task in the text, simply write "None" with no punctuation."""
]
user_message.append("\n# Text to analyze")
user_message.append(text)
self._clear_history()
response = await self.call_model(
summary="Ask the model to extract a task", system_message_content=sys_message, user_content=user_message
)
return response if response != "None" else None
async def extract_advice(self, text: str) -> str | None:
"""
Returns advice from the given text, or None if not found.
"""
sys_message = """You are a helpful and thoughtful assistant."""
user_message: List[Union[str, Image]] = [
"""Does the following text contain any information or advice that might be useful later?
- If so, please copy the information or advice, adding no extra commentary or formatting.
- If there is no potentially useful information or advice at all, simply write "None" with no punctuation."""
]
user_message.append("\n# Text to analyze")
user_message.append(text)
self._clear_history()
response = await self.call_model(
summary="Ask the model to extract advice", system_message_content=sys_message, user_content=user_message
)
return response if response != "None" else None

View File

@@ -0,0 +1,124 @@
import os
import pickle
from typing import Dict, List, Tuple, Union
import chromadb
from chromadb.api.types import (
QueryResult,
)
from chromadb.config import Settings
from .utils.page_logger import PageLogger
class StringSimilarityMap:
"""
Provides storage and similarity-based retrieval of string pairs using a vector database.
Each DB entry is a pair of strings: an input string and an output string.
The input string is embedded and used as the retrieval key.
The output string can be anything, but it's typically used as a dict key.
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
Args:
- reset: True to clear the DB immediately after creation.
- path_to_db_dir: Path to the directory where the DB is stored.
- logger: An optional logger. If None, no logging will be performed.
"""
def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None = None) -> None:
if logger is None:
logger = PageLogger() # Nothing will be logged by this object.
self.logger = logger
self.path_to_db_dir = path_to_db_dir
# Load or create the vector DB on disk.
chromadb_settings = Settings(
anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
)
self.db_client = chromadb.Client(chromadb_settings)
self.vec_db = self.db_client.create_collection("string-pairs", get_or_create=True) # The collection is the DB.
# Load or create the associated string-pair dict on disk.
self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
self.uid_text_dict: Dict[str, Tuple[str, str]] = {}
self.last_string_pair_id = 0
if (not reset) and os.path.exists(self.path_to_dict):
self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "rb") as f:
self.uid_text_dict = pickle.load(f)
self.last_string_pair_id = len(self.uid_text_dict)
if len(self.uid_text_dict) > 0:
self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict)))
self._log_string_pairs()
# Clear the DB if requested.
if reset:
self.reset_db()
def _log_string_pairs(self) -> None:
"""
Logs all string pairs currently in the map.
"""
self.logger.debug("LIST OF STRING PAIRS")
for uid, text in self.uid_text_dict.items():
input_text, output_text = text
self.logger.debug(" ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text))
def save_string_pairs(self) -> None:
"""
Saves the string-pair dict (self.uid_text_dict) to disk.
"""
self.logger.debug("\nSAVING STRING SIMILARITY MAP TO DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "wb") as file:
pickle.dump(self.uid_text_dict, file)
def reset_db(self) -> None:
"""
Forces immediate deletion of the DB's contents, in memory and on disk.
"""
self.logger.debug("\nCLEARING STRING-PAIR MAP")
self.db_client.delete_collection("string-pairs")
self.vec_db = self.db_client.create_collection("string-pairs")
self.uid_text_dict = {}
self.save_string_pairs()
def add_input_output_pair(self, input_text: str, output_text: str) -> None:
"""
Adds one input-output string pair to the DB.
"""
self.last_string_pair_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_string_pair_id)])
self.uid_text_dict[str(self.last_string_pair_id)] = input_text, output_text
self.logger.debug(
"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format(
self.last_string_pair_id, input_text, output_text
)
)
# self._log_string_pairs() # For deeper debugging, uncomment to log all string pairs after each addition.
def get_related_string_pairs(
self, query_text: str, n_results: int, threshold: Union[int, float]
) -> List[Tuple[str, str, float]]:
"""
Retrieves up to n string pairs that are related to the given query text within the specified distance threshold.
"""
string_pairs_with_distances: List[Tuple[str, str, float]] = []
if n_results > len(self.uid_text_dict):
n_results = len(self.uid_text_dict)
if n_results > 0:
results: QueryResult = self.vec_db.query(query_texts=[query_text], n_results=n_results)
num_results = len(results["ids"][0])
for i in range(num_results):
uid = results["ids"][0][i]
input_text = results["documents"][0][i] if results["documents"] else ""
distance = results["distances"][0][i] if results["distances"] else 0.0
if distance < threshold:
input_text_2, output_text = self.uid_text_dict[uid]
assert input_text == input_text_2
self.logger.debug(
"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
input_text, output_text, distance
)
)
string_pairs_with_distances.append((input_text, output_text, distance))
return string_pairs_with_distances

View File

@@ -0,0 +1,478 @@
from typing import TYPE_CHECKING, Awaitable, Callable, List, Tuple, TypedDict
from agentdhal_core.models import (
ChatCompletionClient,
)
from ._memory_bank import Memo, MemoryBank
from ._prompter import Prompter
if TYPE_CHECKING:
from ._memory_bank import MemoryBankConfig
from .utils.grader import Grader
from .utils.page_logger import PageLogger
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class MemoryControllerConfig(TypedDict, total=False):
generalize_task: bool
revise_generalized_task: bool
generate_topics: bool
validate_memos: bool
max_memos_to_retrieve: int
max_train_trials: int
max_test_trials: int
MemoryBank: "MemoryBankConfig"
class MemoryController:
"""
(EXPERIMENTAL, RESEARCH IN PROGRESS)
Implements fast, memory-based learning, and manages the flow of information to and from a memory bank.
Args:
reset: True to empty the memory bank before starting.
client: The model client to use internally.
task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
config: An optional dict that can be used to override the following values:
- generalize_task: Whether to rewrite tasks in more general terms.
- revise_generalized_task: Whether to critique then rewrite the generalized task.
- generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
- validate_memos: Whether to apply a final validation stage to retrieved memos.
- max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
- max_train_trials: The maximum number of learning iterations to attempt when training on a task.
- max_test_trials: The total number of attempts made when testing for failure on a task.
- MemoryBank: A config dict passed to MemoryBank.
logger: An optional logger. If None, a default logger will be created.
Example:
The `task-centric-memory` extra first needs to be installed:
.. code-block:: bash
pip install "agentdhal-ext[task-centric-memory]"
The following code snippet shows how to use this class for the most basic storage and retrieval of memories.:
.. code-block:: python
import asyncio
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.experimental.task_centric_memory import MemoryController
from agentdhal_extensions.experimental.task_centric_memory.utils import PageLogger
async def main() -> None:
client = OpenAIChatCompletionClient(model="gpt-4o")
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful.
memory_controller = MemoryController(reset=True, client=client, logger=logger)
# Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
# Retrieve memories for a new task that's related to only two of the stored memories.
memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
print("{} memories retrieved".format(len(memos)))
for memo in memos:
print("- " + memo.insight)
asyncio.run(main())
"""
def __init__(
self,
reset: bool,
client: ChatCompletionClient,
task_assignment_callback: Callable[[str], Awaitable[Tuple[str, str]]] | None = None,
config: MemoryControllerConfig | None = None,
logger: PageLogger | None = None,
) -> None:
if logger is None:
logger = PageLogger({"level": "DEBUG"})
self.logger = logger
self.logger.enter_function()
# Apply default settings and any config overrides.
self.generalize_task = True
self.revise_generalized_task = True
self.generate_topics = True
self.validate_memos = True
self.max_memos_to_retrieve = 10
self.max_train_trials = 10
self.max_test_trials = 3
memory_bank_config = None
if config is not None:
self.generalize_task = config.get("generalize_task", self.generalize_task)
self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task)
self.generate_topics = config.get("generate_topics", self.generate_topics)
self.validate_memos = config.get("validate_memos", self.validate_memos)
self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve)
self.max_train_trials = config.get("max_train_trials", self.max_train_trials)
self.max_test_trials = config.get("max_test_trials", self.max_test_trials)
memory_bank_config = config.get("MemoryBank", memory_bank_config)
self.client = client
self.task_assignment_callback = task_assignment_callback
self.prompter = Prompter(client, logger)
self.memory_bank = MemoryBank(reset=reset, config=memory_bank_config, logger=logger)
self.grader = Grader(client, logger)
self.logger.leave_function()
def reset_memory(self) -> None:
"""
Empties the memory bank in RAM and on disk.
"""
self.memory_bank.reset()
async def train_on_task(self, task: str, expected_answer: str) -> None:
"""
Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.
"""
self.logger.enter_function()
self.logger.info("Iterate on the task, possibly discovering a useful new insight.\n")
_, insight = await self._iterate_on_task(task, expected_answer)
if insight is None:
self.logger.info("No useful insight was discovered.\n")
else:
self.logger.info("A new insight was created:\n{}".format(insight))
await self.add_memo(insight, task)
self.logger.leave_function()
async def test_on_task(self, task: str, expected_answer: str, num_trials: int = 1) -> Tuple[str, int, int]:
"""
Assigns a task to the agent, along with any relevant memos retrieved from memory.
"""
self.logger.enter_function()
assert self.task_assignment_callback is not None
response = ""
num_successes = 0
for trial in range(num_trials):
self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
task_plus_insights = task
# Try to retrieve any relevant memories from the DB.
filtered_memos = await self.retrieve_relevant_memos(task)
filtered_insights = [memo.insight for memo in filtered_memos]
if len(filtered_insights) > 0:
self.logger.info("Relevant insights were retrieved from memory.\n")
memory_section = self._format_memory_section(filtered_insights)
if len(memory_section) > 0:
task_plus_insights = task + "\n\n" + memory_section
# Attempt to solve the task.
self.logger.info("Try to solve the task.\n")
response, _ = await self.task_assignment_callback(task_plus_insights)
# Check if the response is correct.
response_is_correct, extracted_answer = await self.grader.is_response_correct(
task, response, expected_answer
)
self.logger.info("Extracted answer: {}".format(extracted_answer))
if response_is_correct:
self.logger.info("Answer is CORRECT.\n")
num_successes += 1
else:
self.logger.info("Answer is INCORRECT.\n")
# Calculate the success rate as a percentage, rounded to the nearest whole number.
self.logger.info("\nSuccess rate: {}%\n".format(round((num_successes / num_trials) * 100)))
self.logger.leave_function()
return response, num_successes, num_trials
async def add_memo(self, insight: str, task: None | str = None, index_on_both: bool = True) -> None:
"""
Adds one insight to the memory bank, using the task (if provided) as context.
"""
self.logger.enter_function()
generalized_task = ""
if task is not None:
self.logger.info("\nGIVEN TASK:")
self.logger.info(task)
if self.generalize_task:
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
else:
generalized_task = task
self.logger.info("\nGIVEN INSIGHT:")
self.logger.info(insight)
# Get a list of topics from the insight and the task (if provided).
if task is None:
text_to_index = insight
self.logger.info("\nTOPICS EXTRACTED FROM INSIGHT:")
else:
if index_on_both:
text_to_index = generalized_task.strip() + "\n(Hint: " + insight + ")"
self.logger.info("\nTOPICS EXTRACTED FROM TASK AND INSIGHT COMBINED:")
else:
text_to_index = task
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
if self.generate_topics:
topics = await self.prompter.find_index_topics(text_to_index)
else:
topics = [text_to_index]
self.logger.info("\n".join(topics))
self.logger.info("")
# Add the insight to the memory bank.
self.memory_bank.add_memo(insight, topics, task)
self.logger.leave_function()
async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
"""
Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
This is useful when the task-solution pair is an exemplar of solving a task related to some other task.
"""
self.logger.enter_function()
self.logger.info("\nEXAMPLE TASK:")
self.logger.info(task)
self.logger.info("\nEXAMPLE SOLUTION:")
self.logger.info(solution)
# Get a list of topics from the task.
if self.generate_topics:
topics = await self.prompter.find_index_topics(task.strip())
else:
topics = [task.strip()]
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
self.logger.info("\n".join(topics))
self.logger.info("")
# Add the task and solution (as a combined insight) to the memory bank.
self.memory_bank.add_task_with_solution(task=task, solution=solution, topics=topics)
self.logger.leave_function()
async def retrieve_relevant_memos(self, task: str) -> List[Memo]:
"""
Retrieves any memos from memory that seem relevant to the task.
"""
self.logger.enter_function()
if self.memory_bank.contains_memos():
self.logger.info("\nCURRENT TASK:")
self.logger.info(task)
# Get a list of topics from the generalized task.
if self.generalize_task:
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
else:
generalized_task = task
if self.generate_topics:
task_topics = await self.prompter.find_index_topics(generalized_task)
else:
task_topics = [generalized_task]
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
self.logger.info("\n".join(task_topics))
self.logger.info("")
# Retrieve relevant memos from the memory bank.
memo_list = self.memory_bank.get_relevant_memos(topics=task_topics)
# Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant.
validated_memos: List[Memo] = []
for memo in memo_list:
if len(validated_memos) >= self.max_memos_to_retrieve:
break
if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task):
validated_memos.append(memo)
self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))
for memo in validated_memos:
if memo.task is not None:
self.logger.info("\n TASK: {}".format(memo.task))
self.logger.info("\n INSIGHT: {}".format(memo.insight))
else:
self.logger.info("\nNO SUFFICIENTLY RELEVANT MEMOS WERE FOUND IN MEMORY")
validated_memos = []
self.logger.leave_function()
return validated_memos
def _format_memory_section(self, memories: List[str]) -> str:
"""
Formats a list of memories as a section for appending to a task description.
"""
memory_section = ""
if len(memories) > 0:
memory_section = "## Important insights that may help solve tasks like this\n"
for mem in memories:
memory_section += "- " + mem + "\n"
return memory_section
async def _test_for_failure(
self, task: str, task_plus_insights: str, expected_answer: str
) -> Tuple[bool, str, str]:
"""
Attempts to solve the given task multiple times to find a failure case to learn from.
"""
self.logger.enter_function()
self.logger.info("\nTask description, including any insights: {}".format(task_plus_insights))
self.logger.info("\nExpected answer: {}\n".format(expected_answer))
assert self.task_assignment_callback is not None
failure_found = False
response, work_history = "", ""
for trial in range(self.max_test_trials):
self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
# Attempt to solve the task.
self.logger.info("Try to solve the task.")
response, work_history = await self.task_assignment_callback(task_plus_insights)
response_is_correct, extracted_answer = await self.grader.is_response_correct(
task, response, expected_answer
)
self.logger.info("Extracted answer: {}".format(extracted_answer))
if response_is_correct:
self.logger.info("Answer is CORRECT.\n")
else:
self.logger.info("Answer is INCORRECT.\n Stop testing, and return the details of the failure.\n")
failure_found = True
break
self.logger.leave_function()
return failure_found, response, work_history
async def _iterate_on_task(self, task: str, expected_answer: str) -> Tuple[str, None | str]:
"""
Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.
"""
self.logger.enter_function()
self.logger.info("\nTask description: {}".format(task))
self.logger.info("\nExpected answer: {}\n".format(expected_answer))
final_response = ""
old_memos = await self.retrieve_relevant_memos(task)
old_insights = [memo.insight for memo in old_memos]
new_insights: List[str] = []
last_insight = None
insight = None
successful_insight = None
# Loop until success (or timeout) while learning from failures.
for trial in range(1, self.max_train_trials + 1):
self.logger.info("\n----- TRAIN TRIAL {} -----\n".format(trial))
task_plus_insights = task
# Add any new insights we've accumulated so far.
if last_insight is not None:
memory_section = self._format_memory_section(old_insights + [last_insight])
else:
memory_section = self._format_memory_section(old_insights)
if len(memory_section) > 0:
task_plus_insights += "\n\n" + memory_section
# Can we find a failure case to learn from?
failure_found, response, work_history = await self._test_for_failure(
task, task_plus_insights, expected_answer
)
if not failure_found:
# No. Time to exit the loop.
self.logger.info("\nResponse is CORRECT.\n Stop looking for insights.\n")
# Was this the first trial?
if trial == 1:
# Yes. We should return the successful response, and no insight.
final_response = response
else:
# No. We learned a successful insight, which should be returned.
successful_insight = insight
break
# Will we try again?
if trial == self.max_train_trials:
# No. We're out of training trials.
self.logger.info("\nNo more trials will be attempted.\n")
break
# Try to learn from this failure.
self.logger.info("\nResponse is INCORRECT. Try to learn from this failure.\n")
insight = await self.prompter.learn_from_failure(
task, memory_section, response, expected_answer, work_history
)
self.logger.info("\nInsight: {}\n".format(insight))
new_insights.append(insight)
last_insight = insight
# Return the answer from the last loop.
self.logger.info("\n{}\n".format(final_response))
self.logger.leave_function()
return final_response, successful_insight
async def _append_any_relevant_memories(self, task: str) -> str:
"""
Appends any relevant memories to the task description.
"""
self.logger.enter_function()
filtered_memos = await self.retrieve_relevant_memos(task)
filtered_insights = [memo.insight for memo in filtered_memos]
if len(filtered_insights) > 0:
self.logger.info("Relevant insights were retrieved from memory.\n")
memory_section = self._format_memory_section(filtered_insights)
if len(memory_section) > 0:
task = task + "\n\n" + memory_section
self.logger.leave_function()
return task
async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
"""
Assigns a task to some agent through the task_assignment_callback, along with any relevant memories.
"""
self.logger.enter_function()
assert self.task_assignment_callback is not None
if use_memory:
task = await self._append_any_relevant_memories(task)
# Attempt to solve the task.
self.logger.info("Try to solve the task.\n")
assert should_await
response, _ = await self.task_assignment_callback(task)
self.logger.leave_function()
return response
async def consider_memo_storage(self, text: str) -> str | None:
"""
Tries to extract any advice from the given text and add it to memory.
"""
self.logger.enter_function()
advice = await self.prompter.extract_advice(text)
self.logger.info("Advice: {}".format(advice))
if advice is not None:
await self.add_memo(insight=advice)
self.logger.leave_function()
return advice
async def handle_user_message(self, text: str, should_await: bool = True) -> str:
"""
Handles a user message by extracting any advice as an insight to be stored in memory, and then calling assign_task().
"""
self.logger.enter_function()
# Check for advice.
advice = await self.consider_memo_storage(text)
# Assign the task through the task_assignment_callback, using memory only if no advice was just provided.
response = await self.assign_task(text, use_memory=(advice is None), should_await=should_await)
self.logger.leave_function()
return response

View File

@@ -0,0 +1,15 @@
from .apprentice import Apprentice, ApprenticeConfig
from .chat_completion_client_recorder import ChatCompletionClientRecorder
from .grader import Grader
from .page_logger import PageLogger, PageLoggerConfig
from .teachability import Teachability
__all__ = [
"Apprentice",
"ChatCompletionClientRecorder",
"Grader",
"PageLogger",
"Teachability",
"ApprenticeConfig",
"PageLoggerConfig",
]

View File

@@ -0,0 +1,96 @@
import hashlib
import os
from typing import List, Tuple, Union
from agentdhal_core import FunctionCall, Image
from agentdhal_core.models import FunctionExecutionResult
# Convenience types
UserContent = Union[str, List[Union[str, Image]]]
AssistantContent = Union[str, List[FunctionCall]]
FunctionExecutionContent = List[FunctionExecutionResult]
SystemContent = str
MessageContent = UserContent | AssistantContent | SystemContent | FunctionExecutionContent
def message_content_to_str(message_content: MessageContent | None) -> str:
"""
Converts the message content to a string.
"""
if message_content is None:
return ""
elif isinstance(message_content, str):
return message_content
elif isinstance(message_content, List):
converted: List[str] = list()
for item in message_content:
if isinstance(item, str):
converted.append(item)
elif isinstance(item, Image):
converted.append("<Image>")
else:
converted.append(str(item).rstrip())
return "\n".join(converted)
else:
raise AssertionError("Unexpected response type.")
def text_from_user_content(user_content: UserContent) -> str:
"""
Extracts just the text from the user content.
"""
if isinstance(user_content, str):
return user_content
elif isinstance(user_content, List):
text_list: List[str] = list()
for item in user_content:
if isinstance(item, str):
text_list.append(item.rstrip())
return "\n\n".join(text_list)
else:
raise AssertionError("Unexpected response type.")
def single_image_from_user_content(user_content: UserContent) -> Union[Image, None]:
"""
Extracts a single image from the user content.
"""
image_to_return = None
if isinstance(user_content, str):
return None
elif isinstance(user_content, List):
for item in user_content:
if isinstance(item, Image):
assert image_to_return is None, "Only one image is currently allowed in the user content."
image_to_return = item
else:
raise AssertionError("Unexpected response type.")
return image_to_return
def hash_directory(directory: str, hash_algo: str = "sha256") -> Tuple[str, int, int]:
"""Computes a hash representing the state of a directory, including its structure and file contents."""
hash_func = hashlib.new(hash_algo)
# Also count the number of files and sub-directories
num_files = 0
num_subdirs = 0
for root, dirs, files in sorted(os.walk(directory)): # Ensure order for consistent hashing
num_files += len(files)
num_subdirs += len(dirs)
for dir_name in sorted(dirs):
hash_func.update(dir_name.encode()) # Hash directory names
for file_name in sorted(files):
file_path = os.path.join(root, file_name)
hash_func.update(file_name.encode()) # Hash file names
try:
with open(file_path, "rb") as f:
while chunk := f.read(4096): # Read in chunks
hash_func.update(chunk)
except Exception:
pass
return hash_func.hexdigest(), num_files, num_subdirs

View File

@@ -0,0 +1,257 @@
import random
import time
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.base import TaskResult
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
from agentdhal_core.models import (
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from .page_logger import PageLogger
if TYPE_CHECKING:
from ..memory_controller import MemoryControllerConfig
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class ApprenticeConfig(TypedDict, total=False):
name_of_agent_or_team: str
disable_prefix_caching: bool
MemoryController: "MemoryControllerConfig"
class Apprentice:
"""
A minimal wrapper combining task-centric memory with an agent or team.
Applications may use the Apprentice class, or they may directly instantiate
and call the Memory Controller using this class as an example.
Args:
client: The client to call the model.
config: An optional dict that can be used to override the following values:
- name_of_agent_or_team: The name of the target agent or team for assigning tasks to.
- disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message.
- MemoryController: A config dict passed to MemoryController.
logger: An optional logger. If None, a default logger will be created.
"""
def __init__(
self,
client: ChatCompletionClient,
config: ApprenticeConfig | None = None,
logger: PageLogger | None = None,
) -> None:
if logger is None:
logger = PageLogger({"level": "DEBUG"})
self.logger = logger
# Apply default settings and any config overrides.
self.name_of_agent_or_team = "AssistantAgent"
self.disable_prefix_caching = False
memory_controller_config = None
if config is not None:
self.name_of_agent_or_team = config.get("name_of_agent_or_team", self.name_of_agent_or_team)
self.disable_prefix_caching = config.get("disable_prefix_caching", self.disable_prefix_caching)
memory_controller_config = config.get("MemoryController", memory_controller_config)
self.client = client
if self.disable_prefix_caching:
self.rand = random.Random()
self.rand.seed(int(time.time() * 1000))
# Create the MemoryController, which creates the MemoryBank.
from ..memory_controller import MemoryController
self.memory_controller = MemoryController(
reset=True,
client=self.client,
task_assignment_callback=self.assign_task_to_agent_or_team,
config=memory_controller_config,
logger=self.logger,
)
def reset_memory(self) -> None:
"""
Resets the memory bank.
"""
self.memory_controller.reset_memory()
async def handle_user_message(self, text: str, should_await: bool = True) -> str:
"""
Handles a user message, extracting any advice and assigning a task to the agent.
"""
self.logger.enter_function()
# Pass the user message through to the memory controller.
response = await self.memory_controller.handle_user_message(text, should_await)
self.logger.leave_function()
return response
async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
"""
Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
This is useful when the insight is a demonstration of how to solve a given type of task.
"""
self.logger.enter_function()
# Pass the task and solution through to the memory controller.
await self.memory_controller.add_task_solution_pair_to_memory(task, solution)
self.logger.leave_function()
async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
"""
Assigns a task to the agent, along with any relevant insights/memories.
"""
self.logger.enter_function()
# Pass the task through to the memory controller.
response = await self.memory_controller.assign_task(task, use_memory, should_await)
self.logger.leave_function()
return response
async def train_on_task(self, task: str, expected_answer: str) -> None:
"""
Repeatedly assigns a task to the completion agent, and tries to learn from failures by creating useful insights as memories.
"""
self.logger.enter_function()
# Pass the task through to the memory controller.
await self.memory_controller.train_on_task(task, expected_answer)
self.logger.leave_function()
async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]:
"""
Passes the given task to the target agent or team.
"""
self.logger.enter_function()
# Pass the task through.
if self.name_of_agent_or_team == "MagenticOneGroupChat":
response, work_history = await self._assign_task_to_magentic_one(task)
elif self.name_of_agent_or_team == "AssistantAgent":
response, work_history = await self._assign_task_to_assistant_agent(task)
else:
raise AssertionError("Invalid base agent")
self.logger.leave_function()
return response, work_history
async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[Any, Any]:
"""
Passes the given task to a newly created AssistantAgent with a generic 6-step system prompt.
"""
self.logger.enter_function()
self.logger.info(task)
system_message_content = """You are a helpful and thoughtful assistant.
In responding to every user message, you follow the same multi-step process given here:
1. Explain your understanding of the user message in detail, covering all the important points.
2. List as many possible responses as you can think of.
3. Carefully list and weigh the pros and cons (if any) of each possible response.
4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist.
5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory.
6. Finish by providing your final response in the particular format requested by the user."""
if self.disable_prefix_caching:
# Prepend a random int to disable prefix caching.
random_str = "({})\n\n".format(self.rand.randint(0, 1000000))
system_message_content = random_str + system_message_content
system_message: LLMMessage
if self.client.model_info["family"] == "o1":
# No system message allowed, so pass it as the first user message.
system_message = UserMessage(content=system_message_content, source="User")
else:
# System message allowed.
system_message = SystemMessage(content=system_message_content)
user_message: LLMMessage = UserMessage(content=task, source="User")
system_message_list: List[LLMMessage] = [system_message]
user_message_list: List[LLMMessage] = [user_message]
input_messages: List[LLMMessage] = system_message_list + user_message_list
assistant_agent = AssistantAgent(
"assistant_agent",
self.client,
system_message=system_message_content,
)
# Get the agent's response to the task.
task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
message: BaseAgentEvent | BaseChatMessage = messages[-1]
response_str = message.to_text()
# Log the model call
self.logger.log_model_task(
summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result
)
self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))
# Use the response as the work history as well.
work_history = response_str
self.logger.leave_function()
return response_str, work_history
async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]:
"""
Instantiates a MagenticOneGroupChat team, and passes the given task to it.
"""
self.logger.enter_function()
self.logger.info(task)
general_agent = AssistantAgent(
"general_agent",
self.client,
description="A general GPT-4o AI assistant capable of performing a variety of tasks.",
)
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
web_surfer = MultimodalWebSurfer(
name="web_surfer",
model_client=self.client,
downloads_folder="logs",
debug_dir="logs",
to_save_screenshots=True,
)
from agentdhal_agentchat.teams import MagenticOneGroupChat
team = MagenticOneGroupChat(
[general_agent, web_surfer],
model_client=self.client,
max_turns=20,
)
# Get the team's response to the task.
task_result: TaskResult = await team.run(task=task)
assert isinstance(task_result, TaskResult)
messages = task_result.messages
response_str_list: List[str] = []
for message in messages:
response_str_list.append(message.to_text())
response_str = "\n".join(response_str_list)
self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))
# MagenticOne's response is the chat history, which we use here as the work history.
work_history = response_str
self.logger.leave_function()
return response_str, work_history

View File

@@ -0,0 +1,227 @@
import json
import os
import warnings
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union
from agentdhal_core import CancellationToken
from agentdhal_core.models import (
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelCapabilities, # type: ignore
ModelInfo,
RequestUsage,
)
from agentdhal_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from .page_logger import PageLogger
class RecordDict(TypedDict):
mode: Literal["create", "create_stream"]
messages: List[Mapping[str, Any]]
response: Dict[str, Any]
stream: List[Mapping[str, Any]]
class ChatCompletionClientRecorder(ChatCompletionClient):
"""
A chat completion client that supports fast, large-scale tests of code calling LLM clients.
Two modes are supported:
1. "record": delegates to the underlying client while also recording the input messages and responses,
which are saved to disk when finalize() is called.
2. "replay": loads previously recorded message and responses from disk, then on each call
checks that its message matches the recorded message, and returns the recorded response.
The recorded data is stored as a JSON list of records. Each record is a dictionary with a "mode"
field (either "create" or "create_stream"), a serialized list of messages, and either a "response" (for
create calls) or a "stream" (a list of streamed outputs for create_stream calls).
ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:
- ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client.
- ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages.
"""
def __init__(
self,
client: ChatCompletionClient,
mode: Literal["record", "replay"],
session_file_path: str,
logger: PageLogger | None = None,
) -> None:
if logger is None:
self.logger = PageLogger() # Disabled by default.
else:
self.logger = logger
self.logger.enter_function()
self.logger.info("Wrapping the base client in ChatCompletionClientRecorder.")
self.base_client = client
self.mode = mode
self.session_file_path = os.path.expanduser(session_file_path)
self.records: List[RecordDict] = []
self._record_index = 0
self._num_checked_records = 0
if self.mode == "record":
# Prepare to record the messages and responses.
self.logger.info("Recording mode enabled.\nRecording session to: " + self.session_file_path)
elif self.mode == "replay":
# Load the previously recorded messages and responses from disk.
self.logger.info("Replay mode enabled.\nRetrieving session from: " + self.session_file_path)
try:
with open(self.session_file_path, "r") as f:
self.records = json.load(f)
except Exception as e:
error_str = f"\nFailed to load recorded session: '{self.session_file_path}': {e}"
self.logger.error(error_str)
raise ValueError(error_str) from e
self.logger.leave_function()
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
) -> CreateResult:
current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]
if self.mode == "record":
response = await self.base_client.create(
messages,
tools=tools,
json_output=json_output,
tool_choice=tool_choice,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
rec: RecordDict = {
"mode": "create",
"messages": current_messages,
"response": response.model_dump(),
"stream": [],
}
self.records.append(rec)
return response
elif self.mode == "replay":
if self._record_index >= len(self.records):
error_str = "\nNo more recorded turns to check."
self.logger.error(error_str)
raise ValueError(error_str)
rec = self.records[self._record_index]
if rec.get("mode") != "create":
error_str = f"\nRecorded call type mismatch at index {self._record_index}: expected 'create', got '{rec.get('mode')}'."
self.logger.error(error_str)
raise ValueError(error_str)
recorded_messages = rec.get("messages")
if recorded_messages != current_messages:
error_str = (
"\nCurrent message list doesn't match the recorded message list. See the pagelogs for details."
)
assert recorded_messages is not None
self.logger.log_dict_list(recorded_messages, "recorded message list")
assert current_messages is not None
self.logger.log_dict_list(current_messages, "current message list")
self.logger.error(error_str)
raise ValueError(error_str)
self._record_index += 1
self._num_checked_records += 1
data = rec.get("response")
# Populate a CreateResult from the data.
assert data is not None
result = CreateResult(
content=data.get("content", ""),
finish_reason=data.get("finish_reason", "stop"),
usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
cached=True,
)
return result
else:
error_str = f"\nUnknown mode: {self.mode}"
self.logger.error(error_str)
raise ValueError(error_str)
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
) -> AsyncGenerator[Union[str, CreateResult], None]:
return self.base_client.create_stream(
messages,
tools=tools,
tool_choice=tool_choice,
json_output=json_output,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
async def close(self) -> None:
await self.base_client.close()
def actual_usage(self) -> RequestUsage:
# Calls base_client.actual_usage() and returns the result.
return self.base_client.actual_usage()
def total_usage(self) -> RequestUsage:
# Calls base_client.total_usage() and returns the result.
return self.base_client.total_usage()
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
# Calls base_client.count_tokens() and returns the result.
return self.base_client.count_tokens(messages, tools=tools)
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
# Calls base_client.remaining_tokens() and returns the result.
return self.base_client.remaining_tokens(messages, tools=tools)
@property
def capabilities(self) -> ModelCapabilities: # type: ignore
# Calls base_client.capabilities and returns the result.
warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
return self.base_client.capabilities
@property
def model_info(self) -> ModelInfo:
# Calls base_client.model_info and returns the result.
return self.base_client.model_info
def finalize(self) -> None:
"""
In record mode, saves the accumulated records to disk.
In replay mode, makes sure all the records were checked.
"""
self.logger.enter_function()
if self.mode == "record":
try:
# Create the directory if it doesn't exist.
os.makedirs(os.path.dirname(self.session_file_path), exist_ok=True)
# Write the records to disk.
with open(self.session_file_path, "w") as f:
json.dump(self.records, f, indent=2)
self.logger.info("\nRecorded session was saved to: " + self.session_file_path)
except Exception as e:
error_str = f"Failed to write records to '{self.session_file_path}': {e}"
self.logger.error(error_str)
raise ValueError(error_str) from e
elif self.mode == "replay":
if self._num_checked_records < len(self.records):
error_str = f"\nEarly termination. Only {self._num_checked_records} of the {len(self.records)} recorded turns were checked."
self.logger.error(error_str)
raise ValueError(error_str)
self.logger.info("\nRecorded session was fully replayed and checked.")
self.logger.leave_function()

View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Tuple, Union
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
SystemMessage,
UserMessage,
)
from ._functions import UserContent
from .page_logger import PageLogger
if TYPE_CHECKING:
from .apprentice import Apprentice
class Grader:
"""
Runs basic tests, and determines task success without limitation to string matches.
Args:
client: The client to call the model.
logger: An optional logger. If None, no logging will be performed.
"""
def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
if logger is None:
logger = PageLogger() # Nothing will be logged by this object.
self.logger = logger
self.client = client
# Create the chat history
self._chat_history: List[LLMMessage] = []
async def test_apprentice(
self,
apprentice: Apprentice,
task_description: str,
expected_answer: str,
num_trials: int,
use_memory: bool,
client: ChatCompletionClient,
) -> Tuple[int, int]:
self.logger.enter_function()
self.logger.info("Testing the apprentice on the given task.\n")
num_successes = 0
for trial in range(num_trials):
self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
self.logger.info("Try to solve the task.\n")
response = await apprentice.assign_task(task_description, use_memory=use_memory)
response_is_correct, extracted_answer = await self.is_response_correct(
task_description, response, expected_answer
)
self.logger.info("Extracted answer: {}".format(extracted_answer))
if response_is_correct:
self.logger.info("Answer is CORRECT.\n")
num_successes += 1
else:
self.logger.info("Answer is INCORRECT.\n")
self.logger.info("\nSuccess rate: {}%\n".format(round((num_successes / num_trials) * 100)))
self.logger.leave_function()
return num_successes, num_trials
async def call_model(
self,
summary: str,
user_content: UserContent,
system_message_content: str | None = None,
keep_these_messages: bool = True,
) -> str:
"""
Calls the model client with the given input and returns the response.
"""
# Prepare the input message list
if system_message_content is None:
system_message_content = "You are a helpful assistant."
system_message: LLMMessage
if self.client.model_info["family"] == "o1":
# No system message allowed, so pass it as the first user message.
system_message = UserMessage(content=system_message_content, source="User")
else:
# System message allowed.
system_message = SystemMessage(content=system_message_content)
user_message = UserMessage(content=user_content, source="User")
input_messages = [system_message] + self._chat_history + [user_message]
# Call the model.
response = await self.client.create(input_messages)
assert isinstance(response, CreateResult)
response_string = response.content
assert isinstance(response_string, str)
response_message = AssistantMessage(content=response_string, source="Assistant")
assert isinstance(response_message, AssistantMessage)
# Log the model call
self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)
# Manage the chat history
if keep_these_messages:
self._chat_history.append(user_message)
self._chat_history.append(response_message)
# Return the response as a string
return response_string
def _clear_history(self) -> None:
"""
Empties the message list containing the chat history.
"""
self._chat_history = []
async def is_response_correct(
self, task_description: str, response_to_be_graded: str, correct_answer: str
) -> Tuple[bool, str]:
"""
Determines whether the response is equivalent to the task's correct answer.
"""
self.logger.enter_function()
sys_message = """You are a helpful and thoughtful assistant."""
# Ask the model to extract the answer from the response.
user_message: List[Union[str, Image]] = []
user_message.append("""Your job is to extract a possible answer to the following question from the given text.
- First review the following task.
- Then review the text that follows, which may an answer, plus reasoning that led to the answer.
- Do not attempt to actually solve the task yourself.
- Don't try to judge whether the reasoning steps were correct.
- Simply respond by summarizing the answer described in the text, omitting any other parts of the text.
- If no answer is present can be extracted from the text, simply reply "None".""")
user_message.append("\n# Task description")
user_message.append(task_description)
user_message.append("\n# Text that may contain an answer")
user_message.append(response_to_be_graded)
user_message_arg: UserContent = user_message
self._clear_history()
extracted_answer = await self.call_model(
summary="Ask the model to extract the answer",
system_message_content=sys_message,
user_content=user_message_arg,
)
self.logger.info("Extracted answer: " + extracted_answer)
# Ask the model to check the answer for correctness.
user_message = [
"""Your job is to decide whether a given answer to a task is correct or not.
- You will be given the task description and the correct, gold-standard answer, along with the answer to be graded.
- In general, an answer is correct if it is equivalent to the correct answer.
- Specifically, the given answer must contain the important information from the correct answer, and must not in any way contradict the correct answer.
- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, or extra commentary.
- An answer should be considered correct if it omits information that is clearly inferred.
- For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct.
- Respond with a single character: '1' if the answer to be graded is correct", '0' if not."""
]
user_message.append("\n# Task description")
user_message.append(task_description)
user_message.append("\n# Correct answer")
user_message.append(correct_answer)
user_message.append("\n# Answer to be graded")
user_message.append(extracted_answer)
self._clear_history()
decision = await self.call_model(
summary="Ask the model to check the answer for correctness",
system_message_content=sys_message,
user_content=user_message,
)
self.logger.info("Decision: " + decision)
self.logger.leave_function()
return decision == "1", extracted_answer

View File

@@ -0,0 +1,546 @@
import inspect
import json
import os
import shutil
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
from agentdhal_agentchat.base import TaskResult
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
CreateResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from ._functions import MessageContent, hash_directory
def _html_opening(file_title: str, finished: bool = False) -> str:
"""
Returns the opening text of a simple HTML file.
"""
refresh_tag = '<meta http-equiv="refresh" content="2">' if not finished else ""
st = f"""
<!DOCTYPE html>
<html>
<head>
{refresh_tag}
<title>{file_title}</title>
<style>
body {{font-size: 20px}}
body {{white-space: pre-wrap}}
</style>
</head>
<body>"""
return st
def _html_closing() -> str:
"""
Return the closing text of a simple HTML file.
"""
return """</body></html>"""
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class PageLoggerConfig(TypedDict, total=False):
level: str
path: str
class PageLogger:
"""
Logs text and images to a set of HTML pages, one per function/method, linked to each other in a call tree.
Args:
config: An optional dict that can be used to override the following values:
- level: The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
- path: The path to the directory where the log files will be written.
"""
def __init__(self, config: PageLoggerConfig | None = None) -> None:
self.levels = {
"DEBUG": 10,
"INFO": 20,
"WARNING": 30,
"ERROR": 40,
"CRITICAL": 50,
"NONE": 100,
}
# Apply default settings and any config overrides.
level_str = "NONE" # Default to no logging at all.
self.log_dir = "./pagelogs/default"
if config is not None:
level_str = config.get("level", level_str)
self.log_dir = config.get("path", self.log_dir)
self.level = self.levels[level_str]
self.log_dir = os.path.expanduser(self.log_dir)
# If the logging level is set to NONE or higher, don't log anything.
if self.level >= self.levels["NONE"]:
return
self.page_stack = PageStack()
self.pages: List[Page] = []
self.last_page_id = 0
self.name = "0 Call Tree"
self._create_run_dir()
self.flush()
self.finalized = False
def __del__(self) -> None:
self.finalize()
def finalize(self) -> None:
# Writes a hash of the log directory to a file for change detection.
if self.level >= self.levels["NONE"]:
return
# Don't finalize the log if it has already been finalized.
if self.finalized:
return
# Do nothing if the app is being forced to exit early.
if self.page_stack.size() > 0:
return
self.flush(finished=True)
# Write the hash and other details to a file.
hash_str, num_files, num_subdirs = hash_directory(self.log_dir)
hash_path = os.path.join(self.log_dir, "hash.txt")
with open(hash_path, "w") as f:
f.write(hash_str)
f.write("\n")
f.write("{} files\n".format(num_files))
f.write("{} subdirectories\n".format(num_subdirs))
self.finalized = True
@staticmethod
def _decorate_text(text: str, color: str, weight: str = "bold", demarcate: bool = False) -> str:
"""
Returns a string of text with HTML styling for weight and color.
"""
if demarcate:
text = f"<<<<< {text} >>>>>"
return f'<span style="color: {color}; font-weight: {weight};">{text}</span>'
@staticmethod
def _link_to_image(image_path: str, description: str) -> str:
"""
Returns an HTML string defining a thumbnail link to an image.
"""
# To avoid a bug in heml rendering aht displays underscores to the left of thumbnails,
# define the following string on a single line.
link = f"""<a href="{image_path}"><img src="{image_path}" alt="{description}" style="width: 300px; height: auto;"></a>"""
return link
def _get_next_page_id(self) -> int:
"""Returns the next page id and increments the counter."""
self.last_page_id += 1
return self.last_page_id
def _create_run_dir(self) -> None:
"""Creates a fresh log directory."""
if os.path.exists(self.log_dir):
shutil.rmtree(self.log_dir)
os.makedirs(self.log_dir)
def _add_page(self, summary: str, show_in_call_tree: bool = True, finished: bool = True) -> "Page":
"""
Adds a new page to the log.
"""
page = Page(
page_logger=self,
index=self._get_next_page_id(),
summary=summary,
indent_level=len(self.page_stack.stack),
show_in_call_tree=show_in_call_tree,
finished=finished,
)
self.pages.append(page)
self.flush()
if len(self.page_stack.stack) > 0:
# Insert a link to the new page into the calling page.
self.info("\n" + page.full_link)
return page
def _log_text(self, text: str) -> None:
"""
Adds text to the current page.
"""
page = self.page_stack.top()
if page is not None:
page.add_lines(text, flush=True)
def debug(self, line: str) -> None:
"""
Adds DEBUG text to the current page if debugging level <= DEBUG.
"""
if self.level <= self.levels["DEBUG"]:
self._log_text(line)
def info(self, line: str) -> None:
"""
Adds INFO text to the current page if debugging level <= INFO.
"""
if self.level <= self.levels["INFO"]:
self._log_text(line)
def warning(self, line: str) -> None:
"""
Adds WARNING text to the current page if debugging level <= WARNING.
"""
if self.level <= self.levels["WARNING"]:
self._log_text(line)
def error(self, line: str) -> None:
"""
Adds ERROR text to the current page if debugging level <= ERROR.
"""
if self.level <= self.levels["ERROR"]:
self._log_text(line)
def critical(self, line: str) -> None:
"""
Adds CRITICAL text to the current page if debugging level <= CRITICAL.
"""
if self.level <= self.levels["CRITICAL"]:
self._log_text(line)
def _message_source(self, message: LLMMessage) -> str:
"""
Returns a decorated string indicating the source of a message.
"""
source = "UNKNOWN"
color = "black"
if isinstance(message, SystemMessage):
source = "SYSTEM"
color = "purple"
elif isinstance(message, UserMessage):
source = "USER"
color = "blue"
elif isinstance(message, AssistantMessage):
source = "ASSISTANT"
color = "green"
elif isinstance(message, FunctionExecutionResultMessage):
source = "FUNCTION"
color = "red"
return self._decorate_text(source, color, demarcate=True)
def _format_message_content(self, message_content: MessageContent) -> str:
"""
Formats the message content for logging.
"""
# Start by converting the message content to a list of strings.
content_list: List[str] = []
content = message_content
if isinstance(content, str):
content_list.append(content)
elif isinstance(content, list):
for item in content:
if isinstance(item, str):
content_list.append(item.rstrip())
elif isinstance(item, Image):
# Save the image to disk.
image_filename = str(self._get_next_page_id()) + " image.jpg"
image_path = os.path.join(self.log_dir, image_filename)
item.image.save(image_path)
# Add a link to the image.
content_list.append(self._link_to_image(image_filename, "message_image"))
elif isinstance(item, Dict):
# Add a dictionary to the log.
json_str = json.dumps(item, indent=4)
content_list.append(json_str)
else:
content_list.append(str(item).rstrip())
else:
content_list.append("<UNKNOWN MESSAGE CONTENT>")
# Convert the list of strings to a single string containing newline separators.
output = ""
for item in content_list:
output += f"\n{item}\n"
return output
def log_message_content(self, message_content: MessageContent, summary: str) -> None:
"""
Adds a page containing the message's content, including any images.
"""
if self.level > self.levels["INFO"]:
return None
page = self._add_page(summary=summary, show_in_call_tree=False)
self.page_stack.write_stack_to_page(page)
page.add_lines(self._format_message_content(message_content=message_content))
page.flush()
def log_dict_list(self, content: List[Mapping[str, Any]], summary: str) -> None:
"""
Adds a page containing a list of dicts.
"""
if self.level > self.levels["INFO"]:
return None
page = self._add_page(summary=summary, show_in_call_tree=False)
self.page_stack.write_stack_to_page(page)
for item in content:
json_str = json.dumps(item, indent=4)
page.add_lines(json_str)
page.flush()
def _log_model_messages(
self, summary: str, input_messages: List[LLMMessage], response_str: str, usage: RequestUsage | None
) -> Optional["Page"]:
"""
Adds a page containing the messages to a model (including any input images) and its response.
"""
page = self._add_page(summary=summary, show_in_call_tree=False)
self.page_stack.write_stack_to_page(page)
if usage is not None:
page.add_lines("{} prompt tokens".format(usage.prompt_tokens))
page.add_lines("{} completion tokens".format(usage.completion_tokens))
for m in input_messages:
page.add_lines("\n" + self._message_source(m))
page.add_lines(self._format_message_content(message_content=m.content))
page.add_lines("\n" + self._decorate_text("ASSISTANT RESPONSE", "green", demarcate=True))
page.add_lines("\n" + response_str + "\n")
page.flush()
return page
def log_model_call(
self, summary: str, input_messages: List[LLMMessage], response: CreateResult
) -> Optional["Page"]:
"""
Logs messages sent to a model and the TaskResult response to a new page.
"""
if self.level > self.levels["INFO"]:
return None
response_str = response.content
if not isinstance(response_str, str):
response_str = "??"
page = self._log_model_messages(summary, input_messages, response_str, response.usage)
return page
def log_model_task(
self, summary: str, input_messages: List[LLMMessage], task_result: TaskResult
) -> Optional["Page"]:
"""
Logs messages sent to a model and the TaskResult response to a new page.
"""
if self.level > self.levels["INFO"]:
return None
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
message = messages[-1]
response_str = message.to_text()
if not isinstance(response_str, str):
response_str = "??"
if hasattr(message, "models_usage"):
usage: RequestUsage | None = message.models_usage
else:
usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
page = self._log_model_messages(summary, input_messages, response_str, usage)
return page
def log_link_to_local_file(self, file_path: str) -> str:
"""
Returns a link to a local file in the log.
"""
file_name = os.path.basename(file_path)
link = f'<a href="{file_name}">{file_name}</a>'
return link
def add_link_to_image(self, description: str, source_image_path: str) -> None:
"""
Inserts a thumbnail link to an image to the page.
"""
# Remove every character from the string 'description' that is not alphanumeric or a space.
description = "".join(e for e in description if e.isalnum() or e.isspace())
target_image_filename = str(self._get_next_page_id()) + " - " + description
# Copy the image to the log directory.
local_image_path = os.path.join(self.log_dir, target_image_filename)
shutil.copyfile(source_image_path, local_image_path)
self._log_text("\n" + description)
self._log_text(self._link_to_image(target_image_filename, description))
def flush(self, finished: bool = False) -> None:
"""
Writes the current state of the log to disk.
"""
if self.level > self.levels["INFO"]:
return
# Create a call tree of the log.
call_tree_path = os.path.join(self.log_dir, self.name + ".html")
with open(call_tree_path, "w") as f:
f.write(_html_opening("0 Call Tree", finished=finished))
f.write(f"<h3>{self.name}</h3>")
f.write("\n")
for page in self.pages:
if page.show_in_call_tree:
f.write(page.line_text + "\n")
f.write("\n")
f.write(_html_closing())
def enter_function(self) -> Optional["Page"]:
"""
Adds a new page corresponding to the current function call.
"""
if self.level > self.levels["INFO"]:
return None
page = None
frame_type = inspect.currentframe()
if frame_type is not None:
frame = frame_type.f_back # Get the calling frame
if frame is not None:
# Check if it's a method by looking for 'self' or 'cls' in f_locals
if "self" in frame.f_locals:
class_name = type(frame.f_locals["self"]).__name__
elif "cls" in frame.f_locals:
class_name = frame.f_locals["cls"].__name__
else:
class_name = None # Not part of a class
if class_name is None: # Not part of a class
caller_name = frame.f_code.co_name
else:
caller_name = class_name + "." + frame.f_code.co_name
# Create a new page for this function.
page = self._add_page(summary=caller_name, show_in_call_tree=True, finished=False)
self.page_stack.push(page)
self.page_stack.write_stack_to_page(page)
page.add_lines("\nENTER {}".format(caller_name), flush=True)
return page
def leave_function(self) -> None:
"""
Finishes the page corresponding to the current function call.
"""
if self.level > self.levels["INFO"]:
return None
page = self.page_stack.top()
if page is not None:
page.finished = True
page.add_lines("\nLEAVE {}".format(page.summary), flush=True)
self.page_stack.pop()
class Page:
"""
Represents a single HTML page in the logger output.
Args:
page_logger: The PageLogger object that created this page.
index: The index of the page.
summary: A brief summary of the page's contents for display.
indent_level: The level of indentation in the call tree.
show_in_call_tree: Whether to display the page in the call tree.
finished: Whether the page is complete.
"""
def __init__(
self,
page_logger: PageLogger,
index: int,
summary: str,
indent_level: int,
show_in_call_tree: bool = True,
finished: bool = True,
):
"""
Initializes and writes to a new HTML page.
"""
self.page_logger = page_logger
self.index_str = str(index)
self.summary = summary
self.indent_level = indent_level
self.show_in_call_tree = show_in_call_tree
self.finished = finished
self.file_title = self.index_str + " " + self.summary
self.indentation_text = "|&emsp;" * self.indent_level
self.full_link = f'<a href="{self.index_str}.html">{self.file_title}</a>'
self.line_text = self.indentation_text + self.full_link
self.lines: List[str] = []
self.flush()
def add_lines(self, lines: str, flush: bool = False) -> None:
"""
Adds one or more lines to the page.
"""
lines_to_add: List[str] = []
if "\n" in lines:
lines_to_add = lines.split("\n")
else:
lines_to_add.append(lines)
self.lines.extend(lines_to_add)
if flush:
self.flush()
def flush(self) -> None:
"""
Writes the HTML page to disk.
"""
page_path = os.path.join(self.page_logger.log_dir, self.index_str + ".html")
with open(page_path, "w") as f:
f.write(_html_opening(self.file_title, finished=self.finished))
f.write(f"<h3>{self.file_title}</h3>\n")
for line in self.lines:
try:
f.write(f"{line}\n")
except UnicodeEncodeError:
f.write("UnicodeEncodeError in this line.\n")
f.write(_html_closing())
f.flush()
class PageStack:
"""
A call stack containing a list of currently active function pages in the order they called each other.
"""
def __init__(self) -> None:
self.stack: List[Page] = []
def push(self, page: Page) -> None:
"""Adds a page to the top of the stack."""
self.stack.append(page)
def pop(self) -> Page:
"""Removes and returns the top page from the stack"""
return self.stack.pop()
def size(self) -> int:
"""Returns the number of pages in the stack."""
return len(self.stack)
def top(self) -> Page | None:
"""Returns the top page from the stack without removing it"""
if self.size() == 0:
return None
return self.stack[-1]
def write_stack_to_page(self, page: Page) -> None:
# Logs a properly indented string displaying the current call stack.
page.add_lines("\nCALL STACK")
for stack_page in self.stack:
page.add_lines(stack_page.line_text)
page.add_lines("")
page.add_lines("")
page.flush()

View File

@@ -0,0 +1,133 @@
from typing import TYPE_CHECKING, Any
from agentdhal_core import CancellationToken, Image
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import UserMessage
if TYPE_CHECKING:
from agentdhal_extensions.experimental.task_centric_memory import MemoryController
class Teachability(Memory):
"""
Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.
Steps for usage:
1. Instantiate MemoryController.
2. Instantiate Teachability, passing the memory controller as a parameter.
3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
4. Use the AssistantAgent as usual, such as for chatting with the user.
"""
def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:
"""Initialize Teachability."""
self._memory_controller = memory_controller
self._logger = memory_controller.logger
self._name = name or "teachability"
@property
def name(self) -> str:
"""Get the memory instance identifier."""
return self._name
def _extract_text(self, content_item: str | MemoryContent) -> str:
"""Extract searchable text from content."""
if isinstance(content_item, str):
return content_item
content = content_item.content
mime_type = content_item.mime_type
if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
return str(content)
elif mime_type == MemoryMimeType.JSON:
if isinstance(content, dict):
# Store original JSON string representation
return str(content).lower()
raise ValueError("JSON content must be a dict")
elif isinstance(content, Image):
raise ValueError("Image content cannot be converted to text")
else:
raise ValueError(f"Unsupported content type: {mime_type}")
async def update_context(
self,
model_context: ChatCompletionContext,
) -> UpdateContextResult:
"""
Extracts any advice from the last user turn to be stored in memory,
and adds any relevant memories to the model context.
"""
self._logger.enter_function()
# Extract text from the user's last message
messages = await model_context.get_messages()
if not messages:
self._logger.leave_function()
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
last_message = messages[-1]
last_user_text = last_message.content if isinstance(last_message.content, str) else str(last_message)
# Add any relevant memories to the chat history
query_results = await self.query(last_user_text)
if query_results.results:
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
memory_context = "\nPotentially relevant memories:\n" + "\n".join(memory_strings)
await model_context.add_message(UserMessage(content=memory_context, source="user"))
# Add any user advice to memory
await self._memory_controller.consider_memo_storage(last_user_text)
self._logger.leave_function()
return UpdateContextResult(memories=query_results)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
"""
Tries to extract any advice from the passed content and add it to memory.
"""
self._logger.enter_function()
# Extract text from the incoming content
text = self._extract_text(content)
# Check for advice to add to memory for later turns.
await self._memory_controller.consider_memo_storage(text)
self._logger.leave_function()
async def query(
self,
query: str | MemoryContent,
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> MemoryQueryResult:
"""
Returns any memories that seem relevant to the query.
"""
self._logger.enter_function()
task = self._extract_text(query)
memory_results: list[MemoryContent] = []
filtered_memos = await self._memory_controller.retrieve_relevant_memos(task=task)
filtered_insights = [memo.insight for memo in filtered_memos]
for insight in filtered_insights:
self._logger.info(f"Insight: {insight}")
memory_content = MemoryContent(
content=insight,
mime_type="MemoryMimeType.TEXT",
metadata={},
)
memory_results.append(memory_content)
self._logger.leave_function()
return MemoryQueryResult(results=memory_results)
async def clear(self) -> None:
"""Clear all entries from memory."""
self._memory_controller.reset_memory()
async def close(self) -> None:
"""Clean up memory resources."""
pass # No cleanup needed for this memory implementation

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,4 @@
from ._text_canvas import TextCanvas
from ._text_canvas_memory import TextCanvasMemory
__all__ = ["TextCanvas", "TextCanvasMemory"]

View File

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
class BaseCanvas(ABC):
"""
An abstract protocol for "canvas" objects that maintain
revision history for file-like data. Concrete subclasses
can handle text, images, structured data, etc.
.. warning::
This is an experimental API and may change in the future.
"""
@abstractmethod
def list_files(self) -> Dict[str, int]:
"""
Returns a dict of filename -> latest revision number.
"""
raise NotImplementedError
@abstractmethod
def get_latest_content(self, filename: str) -> Union[str, bytes, Any]:
"""
Returns the latest version of a file's content.
"""
raise NotImplementedError
@abstractmethod
def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
"""
Creates or updates the file content with a new revision.
"""
raise NotImplementedError
@abstractmethod
def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
"""
Returns a diff (in some format) between two revisions.
"""
raise NotImplementedError
@abstractmethod
def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
"""
Applies a patch/diff to the latest revision and increments the revision.
"""
raise NotImplementedError

View File

@@ -0,0 +1,64 @@
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel
from ._text_canvas import TextCanvas
class UpdateFileArgs(BaseModel):
filename: str
new_content: str
class UpdateFileResult(BaseModel):
status: str
class UpdateFileTool(BaseTool[UpdateFileArgs, UpdateFileResult]):
"""
Overwrites or creates a file in the canvas.
"""
def __init__(self, canvas: TextCanvas):
super().__init__(
args_type=UpdateFileArgs,
return_type=UpdateFileResult,
name="update_file",
description="Create/update a file on the canvas with the provided content.",
)
self._canvas = canvas
async def run(self, args: UpdateFileArgs, cancellation_token: CancellationToken) -> UpdateFileResult:
self._canvas.add_or_update_file(args.filename, args.new_content)
return UpdateFileResult(status="OK")
class ApplyPatchArgs(BaseModel):
filename: str
patch_text: str
class ApplyPatchResult(BaseModel):
status: str
class ApplyPatchTool(BaseTool[ApplyPatchArgs, ApplyPatchResult]):
"""
Applies a unified diff patch to the given file on the canvas.
"""
def __init__(self, canvas: TextCanvas):
super().__init__(
args_type=ApplyPatchArgs,
return_type=ApplyPatchResult,
name="apply_patch",
description=(
"Apply a unified diff patch to an existing file on the canvas. "
"The patch must be in diff/patch format. The file must exist or be created first."
),
)
self._canvas = canvas
async def run(self, args: ApplyPatchArgs, cancellation_token: CancellationToken) -> ApplyPatchResult:
self._canvas.apply_patch(args.filename, args.patch_text)
return ApplyPatchResult(status="PATCH APPLIED")

View File

@@ -0,0 +1,192 @@
import difflib
from typing import Any, Dict, List, Union
try: # pragma: no cover
from unidiff import PatchSet
except ModuleNotFoundError: # pragma: no cover
PatchSet = None # type: ignore
from ._canvas import BaseCanvas
class FileRevision:
"""Tracks the history of one file's content."""
__slots__ = ("content", "revision")
def __init__(self, content: str, revision: int) -> None:
self.content: str = content
self.revision: int = revision # e.g. an integer, a timestamp, or git hash
class TextCanvas(BaseCanvas):
"""An inmemory canvas that stores *text* files with full revision history.
.. warning::
This is an experimental API and may change in the future.
Besides the original CRUDlike operations, this enhanced implementation adds:
* **apply_patch** applies patches using the ``unidiff`` library for accurate
hunk application and context line validation.
* **get_revision_content** random access to any historical revision.
* **get_revision_diffs** obtain the list of diffs applied between every
consecutive pair of revisions so that a caller can replay or audit the
full change history.
"""
# ----------------------------------------------------------------------------------
# Construction helpers
# ----------------------------------------------------------------------------------
def __init__(self) -> None:
# For each file we keep an *ordered* list of FileRevision where the last
# element is the most recent. Using a list keeps the memory footprint
# small and preserves order without any extra bookkeeping.
self._files: Dict[str, List[FileRevision]] = {}
# ----------------------------------------------------------------------------------
# Internal utilities
# ----------------------------------------------------------------------------------
def _latest_idx(self, filename: str) -> int:
"""Return the index (not revision number) of the newest revision."""
return len(self._files.get(filename, [])) - 1
def _ensure_file(self, filename: str) -> None:
if filename not in self._files:
raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.")
# ----------------------------------------------------------------------------------
# Revision inspection helpers
# ----------------------------------------------------------------------------------
def get_revision_content(self, filename: str, revision: int) -> str: # NEW 🚀
"""Return the exact content stored in *revision*.
If the revision does not exist an empty string is returned so that
downstream code can handle the "not found" case without exceptions.
"""
for rev in self._files.get(filename, []):
if rev.revision == revision:
return rev.content
return ""
def get_revision_diffs(self, filename: str) -> List[str]: # NEW 🚀
"""Return a *chronological* list of unifieddiffs for *filename*.
Each element in the returned list represents the diff that transformed
revision *n* into revision *n+1* (starting at revision 1 → 2).
"""
revisions = self._files.get(filename, [])
diffs: List[str] = []
for i in range(1, len(revisions)):
older, newer = revisions[i - 1], revisions[i]
diff = difflib.unified_diff(
older.content.splitlines(keepends=True),
newer.content.splitlines(keepends=True),
fromfile=f"{filename}@r{older.revision}",
tofile=f"{filename}@r{newer.revision}",
)
diffs.append("".join(diff))
return diffs
# ----------------------------------------------------------------------------------
# BaseCanvas interface implementation
# ----------------------------------------------------------------------------------
def list_files(self) -> Dict[str, int]:
"""Return a mapping of *filename → latest revision number*."""
return {fname: revs[-1].revision for fname, revs in self._files.items() if revs}
def get_latest_content(self, filename: str) -> str: # noqa: D401 keep API identical
"""Return the most recent content or an empty string if the file is new."""
revs = self._files.get(filename, [])
return revs[-1].content if revs else ""
def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
"""Create *filename* or append a new revision containing *new_content*."""
if isinstance(new_content, bytes):
new_content = new_content.decode("utf-8")
if not isinstance(new_content, str):
raise ValueError(f"Expected str or bytes, got {type(new_content)}")
if filename not in self._files:
self._files[filename] = [FileRevision(new_content, 1)]
else:
last_rev_num = self._files[filename][-1].revision
self._files[filename].append(FileRevision(new_content, last_rev_num + 1))
def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
"""Return a unified diff between *from_revision* and *to_revision*."""
revisions = self._files.get(filename, [])
if not revisions:
return ""
# Fetch the contents for the requested revisions.
from_content = self.get_revision_content(filename, from_revision)
to_content = self.get_revision_content(filename, to_revision)
if from_content == "" and to_content == "": # one (or both) revision ids not found
return ""
diff = difflib.unified_diff(
from_content.splitlines(keepends=True),
to_content.splitlines(keepends=True),
fromfile=f"{filename}@r{from_revision}",
tofile=f"{filename}@r{to_revision}",
)
return "".join(diff)
def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
"""Apply *patch_text* (unified diff) to the latest revision and save a new revision.
Uses the *unidiff* library to accurately apply hunks and validate context lines.
"""
if isinstance(patch_data, bytes):
patch_data = patch_data.decode("utf-8")
if not isinstance(patch_data, str):
raise ValueError(f"Expected str or bytes, got {type(patch_data)}")
self._ensure_file(filename)
original_content = self.get_latest_content(filename)
if PatchSet is None:
raise ImportError(
"The 'unidiff' package is required for patch application. Install with 'pip install unidiff'."
)
patch = PatchSet(patch_data)
# Our canvas stores exactly one file per patch operation so we
# use the first (and only) patched_file object.
if not patch:
raise ValueError("Empty patch text provided.")
patched_file = patch[0]
working_lines = original_content.splitlines(keepends=True)
line_offset = 0
for hunk in patched_file:
# Calculate the slice boundaries in the *current* working copy.
start = hunk.source_start - 1 + line_offset
end = start + hunk.source_length
# Build the replacement block for this hunk.
replacement: List[str] = []
for line in hunk:
if line.is_added or line.is_context:
replacement.append(line.value)
# removed lines (line.is_removed) are *not* added.
# Replace the slice with the hunkresult.
working_lines[start:end] = replacement
line_offset += len(replacement) - (end - start)
new_content = "".join(working_lines)
# Finally commit the new revision.
self.add_or_update_file(filename, new_content)
# ----------------------------------------------------------------------------------
# Convenience helpers
# ----------------------------------------------------------------------------------
def get_all_contents_for_context(self) -> str: # noqa: D401 keep public API stable
"""Return a summarised view of every file and its *latest* revision."""
out: List[str] = ["=== CANVAS FILES ==="]
for fname, revs in self._files.items():
latest = revs[-1]
out.append(f"File: {fname} (rev {latest.revision}):\n{latest.content}\n")
out.append("=== END OF CANVAS ===")
return "\n".join(out)

View File

@@ -0,0 +1,229 @@
from typing import Any, Optional
from agentdhal_core import CancellationToken
from agentdhal_core.memory import (
Memory,
MemoryContent,
MemoryMimeType,
MemoryQueryResult,
UpdateContextResult,
)
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from ._canvas_writer import ApplyPatchTool, UpdateFileTool
from ._text_canvas import TextCanvas
class TextCanvasMemory(Memory):
"""
A memory implementation that uses a Canvas for storing file-like content.
Inserts the current state of the canvas into the ChatCompletionContext on each turn.
.. warning::
This is an experimental API and may change in the future.
The TextCanvasMemory provides a persistent, file-like storage mechanism that can be used
by agents to read and write content. It automatically injects the current state of all files
in the canvas into the model context before each inference.
This is particularly useful for:
- Allowing agents to create and modify documents over multiple turns
- Enabling collaborative document editing between multiple agents
- Maintaining persistent state across conversation turns
- Working with content too large to fit in a single message
The canvas provides tools for:
- Creating or updating files with new content
- Applying patches (unified diff format) to existing files
Examples:
**Example: Using TextCanvasMemory with an AssistantAgent**
The following example demonstrates how to create a TextCanvasMemory and use it with
an AssistantAgent to write and update a story file.
.. code-block:: python
import asyncio
from agentdhal_core import CancellationToken
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.messages import TextMessage
from agentdhal_extensions.memory.canvas import TextCanvasMemory
async def main():
# Create a model client
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
# Create the canvas memory
text_canvas_memory = TextCanvasMemory()
# Get tools for working with the canvas
update_file_tool = text_canvas_memory.get_update_file_tool()
apply_patch_tool = text_canvas_memory.get_apply_patch_tool()
# Create an agent with the canvas memory and tools
writer_agent = AssistantAgent(
name="Writer",
model_client=model_client,
description="A writer agent that creates and updates stories.",
system_message='''
You are a Writer Agent. Your focus is to generate a story based on the user's request.
Instructions for using the canvas:
- The story should be stored on the canvas in a file named "story.md".
- If "story.md" does not exist, create it by calling the 'update_file' tool.
- If "story.md" already exists, generate a unified diff (patch) from the current
content to the new version, and call the 'apply_patch' tool to apply the changes.
IMPORTANT: Do not include the full story text in your chat messages.
Only write the story content to the canvas using the tools.
''',
tools=[update_file_tool, apply_patch_tool],
memory=[text_canvas_memory],
)
# Send a message to the agent
await writer_agent.on_messages(
[TextMessage(content="Write a short story about a bunny and a sunflower.", source="user")],
CancellationToken(),
)
# Retrieve the content from the canvas
story_content = text_canvas_memory.canvas.get_latest_content("story.md")
print("Story content from canvas:")
print(story_content)
if __name__ == "__main__":
asyncio.run(main())
**Example: Using TextCanvasMemory with multiple agents**
The following example shows how to use TextCanvasMemory with multiple agents
collaborating on the same document.
.. code-block:: python
import asyncio
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_agentchat.conditions import TextMentionTermination
from agentdhal_extensions.memory.canvas import TextCanvasMemory
async def main():
# Create a model client
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
# Create the shared canvas memory
text_canvas_memory = TextCanvasMemory()
update_file_tool = text_canvas_memory.get_update_file_tool()
apply_patch_tool = text_canvas_memory.get_apply_patch_tool()
# Create a writer agent
writer_agent = AssistantAgent(
name="Writer",
model_client=model_client,
description="A writer agent that creates stories.",
system_message="You write children's stories on the canvas in story.md.",
tools=[update_file_tool, apply_patch_tool],
memory=[text_canvas_memory],
)
# Create a critique agent
critique_agent = AssistantAgent(
name="Critique",
model_client=model_client,
description="A critique agent that provides feedback on stories.",
system_message="You review the story.md file and provide constructive feedback.",
memory=[text_canvas_memory],
)
# Create a team with both agents
team = RoundRobinGroupChat(
participants=[writer_agent, critique_agent],
termination_condition=TextMentionTermination("TERMINATE"),
max_turns=10,
)
# Run the team on a task
await team.run(task="Create a children's book about a bunny and a sunflower")
# Get the final story
story = text_canvas_memory.canvas.get_latest_content("story.md")
print(story)
if __name__ == "__main__":
asyncio.run(main())
"""
def __init__(self, canvas: Optional[TextCanvas] = None):
super().__init__()
self.canvas = canvas if canvas is not None else TextCanvas()
async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
"""
Inject the entire canvas summary (or a selected subset) as reference data.
Here, we just put it into a system message, but you could customize.
"""
snapshot = self.canvas.get_all_contents_for_context()
if snapshot.strip():
msg = SystemMessage(content=snapshot)
await model_context.add_message(msg)
# Return it for debugging/logging
memory_content = MemoryContent(content=snapshot, mime_type=MemoryMimeType.TEXT)
return UpdateContextResult(memories=MemoryQueryResult(results=[memory_content]))
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
async def query(
self, query: str | MemoryContent, cancellation_token: Optional[CancellationToken] = None, **kwargs: Any
) -> MemoryQueryResult:
"""
Potentially search for matching filenames or file content.
This example returns empty.
"""
return MemoryQueryResult(results=[])
async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None:
"""
Example usage: Possibly interpret content as a patch or direct file update.
Could also be done by a specialized "CanvasTool" instead.
"""
# NO-OP here, leaving actual changes to the CanvasTool
pass
async def clear(self) -> None:
"""Clear the entire canvas by replacing it with a new empty instance."""
# Create a new TextCanvas instance instead of calling __init__ directly
self.canvas = TextCanvas()
async def close(self) -> None:
pass
def get_update_file_tool(self) -> UpdateFileTool:
"""
Returns an UpdateFileTool instance that works with this memory's canvas.
"""
return UpdateFileTool(self.canvas)
def get_apply_patch_tool(self) -> ApplyPatchTool:
"""
Returns an ApplyPatchTool instance that works with this memory's canvas.
"""
return ApplyPatchTool(self.canvas)

View File

@@ -0,0 +1,21 @@
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
from ._chromadb import ChromaDBVectorMemory
__all__ = [
"ChromaDBVectorMemory",
"ChromaDBVectorMemoryConfig",
"PersistentChromaDBVectorMemoryConfig",
"HttpChromaDBVectorMemoryConfig",
"DefaultEmbeddingFunctionConfig",
"SentenceTransformerEmbeddingFunctionConfig",
"OpenAIEmbeddingFunctionConfig",
"CustomEmbeddingFunctionConfig",
]

View File

@@ -0,0 +1,137 @@
"""Configuration classes for ChromaDB vector memory."""
from typing import Any, Callable, Dict, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class DefaultEmbeddingFunctionConfig(BaseModel):
"""Configuration for the default ChromaDB embedding function.
Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).
.. versionadded:: v0.4.1
Support for custom embedding functions in ChromaDB memory.
"""
function_type: Literal["default"] = "default"
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
"""Configuration for SentenceTransformer embedding functions.
Allows specifying a custom SentenceTransformer model for embeddings.
.. versionadded:: v0.4.1
Support for custom embedding functions in ChromaDB memory.
Args:
model_name (str): Name of the SentenceTransformer model to use.
Defaults to "all-MiniLM-L6-v2".
Example:
.. code-block:: python
from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig
_ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
"""
function_type: Literal["sentence_transformer"] = "sentence_transformer"
model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use")
class OpenAIEmbeddingFunctionConfig(BaseModel):
"""Configuration for OpenAI embedding functions.
Uses OpenAI's embedding API for generating embeddings.
.. versionadded:: v0.4.1
Support for custom embedding functions in ChromaDB memory.
Args:
api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".
Example:
.. code-block:: python
from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig
_ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
"""
function_type: Literal["openai"] = "openai"
api_key: str = Field(default="", description="OpenAI API key")
model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name")
class CustomEmbeddingFunctionConfig(BaseModel):
"""Configuration for custom embedding functions.
Allows using a custom function that returns a ChromaDB-compatible embedding function.
.. versionadded:: v0.4.1
Support for custom embedding functions in ChromaDB memory.
.. warning::
Configurations containing custom functions are not serializable.
Args:
function (Callable): Function that returns a ChromaDB-compatible embedding function.
params (Dict[str, Any]): Parameters to pass to the function.
"""
function_type: Literal["custom"] = "custom"
function: Callable[..., Any] = Field(description="Function that returns an embedding function")
params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function")
# Tagged union type for embedding function configurations
EmbeddingFunctionConfig = Annotated[
Union[
DefaultEmbeddingFunctionConfig,
SentenceTransformerEmbeddingFunctionConfig,
OpenAIEmbeddingFunctionConfig,
CustomEmbeddingFunctionConfig,
],
Field(discriminator="function_type"),
]
class ChromaDBVectorMemoryConfig(BaseModel):
"""Base configuration for ChromaDB-based memory implementation.
.. versionchanged:: v0.4.1
Added support for custom embedding functions via embedding_function_config.
"""
client_type: Literal["persistent", "http"]
collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
k: int = Field(default=3, description="Number of results to return in queries")
score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
tenant: str = Field(default="default_tenant", description="Tenant to use")
database: str = Field(default="default_database", description="Database to use")
embedding_function_config: EmbeddingFunctionConfig = Field(
default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function"
)
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
"""Configuration for persistent ChromaDB memory."""
client_type: Literal["persistent", "http"] = "persistent"
persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
"""Configuration for HTTP ChromaDB memory."""
client_type: Literal["persistent", "http"] = "http"
host: str = Field(default="localhost", description="Host of the remote server")
port: int = Field(default=8000, description="Port of the remote server")
ssl: bool = Field(default=False, description="Whether to use HTTPS")
headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")

View File

@@ -0,0 +1,459 @@
import logging
import uuid
from typing import Any, List
from agentdhal_core import CancellationToken, Component, Image
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from chromadb import HttpClient, PersistentClient
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Document, Metadata
from typing_extensions import Self
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
logger = logging.getLogger(__name__)
try:
from chromadb.api import ClientAPI
except ImportError as e:
raise ImportError(
"To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`"
) from e
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by ChromaDB.
`ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for
storing and retrieving content based on semantic similarity. It enhances agents with the ability
to recall contextually relevant information during conversations by leveraging vector embeddings
to find similar content.
This implementation serves as a reference for more complex memory systems using vector embeddings.
For advanced use cases requiring specialized formatting of retrieved content, users should extend
this class and override the `update_context()` method.
This implementation requires the ChromaDB extra to be installed. Install with:
.. code-block:: bash
pip install "agentdhal-ext[chromadb]"
Args:
config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory.
If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values.
Two config types are supported:
* PersistentChromaDBVectorMemoryConfig: For local storage
* HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server
Example:
.. code-block:: python
import os
import asyncio
from pathlib import Path
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core.memory import MemoryContent, MemoryMimeType
from agentdhal_extensions.memory.chromadb import (
ChromaDBVectorMemory,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
OpenAIEmbeddingFunctionConfig,
)
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
def get_weather(city: str) -> str:
return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."
def fahrenheit_to_celsius(fahrenheit: float) -> float:
return (fahrenheit - 32) * 5.0 / 9.0
async def main() -> None:
# Use default embedding function
default_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="user_preferences",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
k=3, # Return top 3 results
score_threshold=0.5, # Minimum similarity score
)
)
# Using a custom SentenceTransformer model
custom_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="multilingual_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
model_name="paraphrase-multilingual-mpnet-base-v2"
),
)
)
# Using OpenAI embeddings
openai_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="openai_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=OpenAIEmbeddingFunctionConfig(
api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
),
)
)
# Add user preferences to memory
await openai_memory.add(
MemoryContent(
content="The user prefers weather temperatures in Celsius",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "units"},
)
)
# Create assistant agent with ChromaDB memory
assistant = AssistantAgent(
name="assistant",
model_client=OpenAIChatCompletionClient(
model="gpt-4.1",
),
tools=[
get_weather,
fahrenheit_to_celsius,
],
max_tool_iterations=10,
memory=[openai_memory],
)
# The memory will automatically retrieve relevant content during conversations
await Console(assistant.run_stream(task="What's the temperature in New York?"))
# Remember to close the memory when finished
await default_memory.close()
await custom_memory.close()
await openai_memory.close()
asyncio.run(main())
Output:
.. code-block:: text
---------- TextMessage (user) ----------
What's the temperature in New York?
---------- MemoryQueryEvent (assistant) ----------
[MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)]
---------- TextMessage (assistant) ----------
The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
"""
component_config_schema = ChromaDBVectorMemoryConfig
component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory"
def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
self._config = config or PersistentChromaDBVectorMemoryConfig()
self._client: ClientAPI | None = None
self._collection: Collection | None = None
@property
def collection_name(self) -> str:
"""Get the name of the ChromaDB collection."""
return self._config.collection_name
def _create_embedding_function(self) -> Any:
"""Create an embedding function based on the configuration.
Returns:
A ChromaDB-compatible embedding function.
Raises:
ValueError: If the embedding function type is unsupported.
ImportError: If required dependencies are not installed.
"""
try:
from chromadb.utils import embedding_functions
except ImportError as e:
raise ImportError(
"ChromaDB embedding functions not available. Ensure chromadb is properly installed."
) from e
config = self._config.embedding_function_config
if isinstance(config, DefaultEmbeddingFunctionConfig):
return embedding_functions.DefaultEmbeddingFunction()
elif isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
try:
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name)
except Exception as e:
raise ImportError(
f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
) from e
elif isinstance(config, OpenAIEmbeddingFunctionConfig):
try:
return embedding_functions.OpenAIEmbeddingFunction(api_key=config.api_key, model_name=config.model_name)
except Exception as e:
raise ImportError(
f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
f"Ensure openai is installed and API key is valid. Error: {e}"
) from e
elif isinstance(config, CustomEmbeddingFunctionConfig):
try:
return config.function(**config.params)
except Exception as e:
raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e
else:
raise ValueError(f"Unsupported embedding function config type: {type(config)}")
def _ensure_initialized(self) -> None:
"""Ensure ChromaDB client and collection are initialized."""
if self._client is None:
try:
from chromadb.config import Settings
settings = Settings(allow_reset=self._config.allow_reset)
if isinstance(self._config, PersistentChromaDBVectorMemoryConfig):
self._client = PersistentClient(
path=self._config.persistence_path,
settings=settings,
tenant=self._config.tenant,
database=self._config.database,
)
elif isinstance(self._config, HttpChromaDBVectorMemoryConfig):
self._client = HttpClient(
host=self._config.host,
port=self._config.port,
ssl=self._config.ssl,
headers=self._config.headers,
settings=settings,
tenant=self._config.tenant,
database=self._config.database,
)
else:
raise ValueError(f"Unsupported config type: {type(self._config)}")
except Exception as e:
logger.error(f"Failed to initialize ChromaDB client: {e}")
raise
if self._collection is None:
try:
# Create embedding function
embedding_function = self._create_embedding_function()
# Create or get collection with embedding function
self._collection = self._client.get_or_create_collection(
name=self._config.collection_name,
metadata={"distance_metric": self._config.distance_metric},
embedding_function=embedding_function,
)
except Exception as e:
logger.error(f"Failed to get/create collection: {e}")
raise
def _extract_text(self, content_item: str | MemoryContent) -> str:
"""Extract searchable text from content."""
if isinstance(content_item, str):
return content_item
content = content_item.content
mime_type = content_item.mime_type
if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
return str(content)
elif mime_type == MemoryMimeType.JSON:
if isinstance(content, dict):
# Store original JSON string representation
return str(content).lower()
raise ValueError("JSON content must be a dict")
elif isinstance(content, Image):
raise ValueError("Image content cannot be converted to text")
else:
raise ValueError(f"Unsupported content type: {mime_type}")
def _calculate_score(self, distance: float) -> float:
"""Convert ChromaDB distance to a similarity score."""
if self._config.distance_metric == "cosine":
return 1.0 - (distance / 2.0)
return 1.0 / (1.0 + distance)
async def update_context(
self,
model_context: ChatCompletionContext,
) -> UpdateContextResult:
messages = await model_context.get_messages()
if not messages:
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
# Extract query from last message
last_message = messages[-1]
query_text = last_message.content if isinstance(last_message.content, str) else str(last_message)
# Query memory and get results
query_results = await self.query(query_text)
if query_results.results:
# Format results for context
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
memory_context = "\nRelevant memory content:\n" + "\n".join(memory_strings)
# Add to context
await model_context.add_message(SystemMessage(content=memory_context))
return UpdateContextResult(memories=query_results)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
self._ensure_initialized()
if self._collection is None:
raise RuntimeError("Failed to initialize ChromaDB")
try:
# Extract text from content
text = self._extract_text(content)
# Use metadata directly from content
metadata_dict = content.metadata or {}
metadata_dict["mime_type"] = str(content.mime_type)
# Add to ChromaDB
self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())])
except Exception as e:
logger.error(f"Failed to add content to ChromaDB: {e}")
raise
async def query(
self,
query: str | MemoryContent,
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> MemoryQueryResult:
self._ensure_initialized()
if self._collection is None:
raise RuntimeError("Failed to initialize ChromaDB")
try:
# Extract text for query
query_text = self._extract_text(query)
# Query ChromaDB
results = self._collection.query(
query_texts=[query_text],
n_results=self._config.k,
include=["documents", "metadatas", "distances"],
**kwargs,
)
# Convert results to MemoryContent list
memory_results: List[MemoryContent] = []
if (
not results
or not results.get("documents")
or not results.get("metadatas")
or not results.get("distances")
):
return MemoryQueryResult(results=memory_results)
documents: List[Document] = results["documents"][0] if results["documents"] else []
metadatas: List[Metadata] = results["metadatas"][0] if results["metadatas"] else []
distances: List[float] = results["distances"][0] if results["distances"] else []
ids: List[str] = results["ids"][0] if results["ids"] else []
for doc, metadata_dict, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False):
# Calculate score
score = self._calculate_score(distance)
metadata = dict(metadata_dict)
metadata["score"] = score
metadata["id"] = doc_id
if self._config.score_threshold is not None and score < self._config.score_threshold:
continue
# Extract mime_type from metadata
mime_type = str(metadata_dict.get("mime_type", MemoryMimeType.TEXT.value))
# Create MemoryContent
content = MemoryContent(
content=doc,
mime_type=mime_type,
metadata=metadata,
)
memory_results.append(content)
return MemoryQueryResult(results=memory_results)
except Exception as e:
logger.error(f"Failed to query ChromaDB: {e}")
raise
async def clear(self) -> None:
self._ensure_initialized()
if self._collection is None:
raise RuntimeError("Failed to initialize ChromaDB")
try:
results = self._collection.get()
if results and results["ids"]:
self._collection.delete(ids=results["ids"])
except Exception as e:
logger.error(f"Failed to clear ChromaDB collection: {e}")
raise
async def close(self) -> None:
"""Clean up ChromaDB client and resources."""
self._collection = None
self._client = None
async def reset(self) -> None:
self._ensure_initialized()
if not self._config.allow_reset:
raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.")
if self._client is not None:
try:
self._client.reset()
except Exception as e:
logger.error(f"Error during ChromaDB reset: {e}")
finally:
self._collection = None
def _to_config(self) -> ChromaDBVectorMemoryConfig:
"""Serialize the memory configuration."""
return self._config
@classmethod
def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
"""Deserialize the memory configuration."""
return cls(config=config)

View File

@@ -0,0 +1,6 @@
from ._mem0 import Mem0Memory, Mem0MemoryConfig
__all__ = [
"Mem0Memory",
"Mem0MemoryConfig",
]

View File

@@ -0,0 +1,449 @@
import io
import logging
import uuid
from contextlib import redirect_stderr, redirect_stdout
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, cast
from agentdhal_core import CancellationToken, Component, ComponentBase
from agentdhal_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from mem0 import Memory as Memory0
from mem0 import MemoryClient
from pydantic import BaseModel, Field
from typing_extensions import Self
logger = logging.getLogger(__name__)
logging.getLogger("chromadb").setLevel(logging.ERROR)
class Mem0MemoryConfig(BaseModel):
"""Configuration for Mem0Memory component."""
user_id: Optional[str] = Field(
default=None, description="User ID for memory operations. If not provided, a UUID will be generated."
)
limit: int = Field(default=10, description="Maximum number of results to return in memory queries.")
is_cloud: bool = Field(default=True, description="Whether to use cloud Mem0 client (True) or local client (False).")
api_key: Optional[str] = Field(
default=None, description="API key for cloud Mem0 client. Required if is_cloud=True."
)
config: Optional[Dict[str, Any]] = Field(
default=None, description="Configuration dictionary for local Mem0 client. Required if is_cloud=False."
)
class MemoryResult(TypedDict, total=False):
memory: str
score: float
metadata: Dict[str, Any]
created_at: str
updated_at: str
categories: List[str]
# pyright: reportGeneralTypeIssues=false
class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]):
"""Mem0 memory implementation for AutoGen.
This component integrates with Mem0.ai's memory system, providing an implementation
of AutoGen's Memory interface. It supports both cloud and local backends through the
mem0ai Python package.
To use this component, you need to have the `mem0` (for cloud-only) or `mem0-local` (for local)
extra installed for the `autogen-ext` package:
.. code-block:: bash
pip install -U "agentdhal-ext[mem0]" # For cloud-based Mem0
pip install -U "agentdhal-ext[mem0-local]" # For local Mem0
The memory component can store and retrieve information that agents need to remember
across conversations. It also provides context updating for language models with
relevant memories.
Examples:
.. code-block:: python
import asyncio
from agentdhal_extensions.memory.mem0 import Mem0Memory
from agentdhal_core.memory import MemoryContent
async def main() -> None:
# Create a local Mem0Memory (no API key required)
memory = Mem0Memory(
is_cloud=False,
config={"path": ":memory:"}, # Use in-memory storage for testing
)
print("Memory initialized successfully!")
# Add something to memory
test_content = "User likes the color blue."
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
print(f"Added content: {test_content}")
# Retrieve memories with a search query
results = await memory.query("What color does the user like?")
print(f"Query results: {len(results.results)} found")
for i, result in enumerate(results.results):
print(f"Result {i+1}: {result}")
asyncio.run(main())
Output:
.. code-block:: text
Memory initialized successfully!
Added content: User likes the color blue.
Query results: 1 found
Result 1: content='User likes the color blue' mime_type='text/plain' metadata={'score': 0.6977155806281953, 'created_at': datetime.datetime(2025, 7, 6, 17, 25, 18, 754725, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200)))}
Using it with an :class:`~agentdhal_agentchat.agents.AssistantAgent`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.memory import MemoryContent
from agentdhal_extensions.memory.mem0 import Mem0Memory
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
async def main() -> None:
# Create a model client
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
# Create a Mem0 memory instance
memory = Mem0Memory(
user_id="user123",
is_cloud=False,
config={"path": ":memory:"}, # Use in-memory storage for testing
)
# Add something to memory
test_content = "User likes the color blue."
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
# Create an assistant agent with Mem0 memory
agent = AssistantAgent(
name="assistant",
model_client=model_client,
memory=[memory],
system_message="You are a helpful assistant that remembers user preferences.",
)
# Run a sample task
result = await agent.run(task="What color does the user like?")
print(result.messages[-1].content) # type: ignore
asyncio.run(main())
Output:
.. code-block:: text
User likes the color blue.
Args:
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
limit: Maximum number of results to return in memory queries.
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided.
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
"""
component_type = "memory"
component_provider_override = "agentdhal_extensions.memory.mem0.Mem0Memory"
component_config_schema = Mem0MemoryConfig
def __init__(
self,
user_id: Optional[str] = None,
limit: int = 10,
is_cloud: bool = True,
api_key: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
) -> None:
# Validate parameters
if not is_cloud and config is None:
raise ValueError("config is required when using local Mem0 client (is_cloud=False)")
# Initialize instance variables
self._user_id = user_id or str(uuid.uuid4())
self._limit = limit
self._is_cloud = is_cloud
self._api_key = api_key
self._config = config
# Initialize client
if self._is_cloud:
self._client = MemoryClient(api_key=self._api_key)
else:
assert self._config is not None
config_dict = self._config
self._client = Memory0.from_config(config_dict=config_dict) # type: ignore
@property
def user_id(self) -> str:
"""Get the user ID for memory operations."""
return self._user_id
@property
def limit(self) -> int:
"""Get the maximum number of results to return in memory queries."""
return self._limit
@property
def is_cloud(self) -> bool:
"""Check if the Mem0 client is cloud-based."""
return self._is_cloud
@property
def config(self) -> Optional[Dict[str, Any]]:
"""Get the configuration for the Mem0 client."""
return self._config
async def add(
self,
content: MemoryContent,
cancellation_token: Optional[CancellationToken] = None,
) -> None:
"""Add content to memory.
Args:
content: The memory content to add.
cancellation_token: Optional token to cancel operation.
Raises:
Exception: If there's an error adding content to mem0 memory.
"""
# Extract content based on mime type
if hasattr(content, "content") and hasattr(content, "mime_type"):
if content.mime_type in ["text/plain", "text/markdown"]:
message = str(content.content)
elif content.mime_type == "application/json":
# Convert JSON content to string representation
if isinstance(content.content, str):
message = content.content
else:
# Convert dict or other JSON serializable objects to string
import json
message = json.dumps(content.content)
else:
message = str(content.content)
# Extract metadata
metadata = content.metadata or {}
else:
# Handle case where content is directly provided as string
message = str(content)
metadata = {}
# Check if operation is cancelled
if cancellation_token is not None and cancellation_token.cancelled: # type: ignore
return
# Add to mem0 client
try:
user_id = metadata.pop("user_id", self._user_id)
# Suppress warning messages from mem0 MemoryClient
kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"}
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
self._client.add([{"role": "user", "content": message}], user_id=user_id, metadata=metadata, **kwargs) # type: ignore
except Exception as e:
# Log the error but don't crash
logger.error(f"Error adding to mem0 memory: {str(e)}")
raise
async def query(
self,
query: str | MemoryContent = "",
cancellation_token: Optional[CancellationToken] = None,
**kwargs: Any,
) -> MemoryQueryResult:
"""Query memory for relevant content.
Args:
query: The query to search for, either as string or MemoryContent.
cancellation_token: Optional token to cancel operation.
**kwargs: Additional query parameters to pass to mem0.
Returns:
MemoryQueryResult containing search results.
"""
# Extract query text
if isinstance(query, str):
query_text = query
elif hasattr(query, "content"):
query_text = str(query.content)
else:
query_text = str(query)
# Check if operation is cancelled
if (
cancellation_token
and hasattr(cancellation_token, "cancelled")
and getattr(cancellation_token, "cancelled", False)
):
return MemoryQueryResult(results=[])
try:
limit = kwargs.pop("limit", self._limit)
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
# Query mem0 client
results = self._client.search( # type: ignore
query_text,
user_id=self._user_id,
limit=limit,
**kwargs,
)
# Type-safe handling of results
if isinstance(results, dict) and "results" in results:
result_list = cast(List[MemoryResult], results["results"])
else:
result_list = cast(List[MemoryResult], results)
# Convert results to MemoryContent objects
memory_contents: List[MemoryContent] = []
for result in result_list:
content_text = result.get("memory", "")
metadata: Dict[str, Any] = {}
if "metadata" in result and result["metadata"]:
metadata = result["metadata"]
# Add relevant fields to metadata
if "score" in result:
metadata["score"] = result["score"]
# For created_at
if "created_at" in result and result.get("created_at"):
try:
metadata["created_at"] = datetime.fromisoformat(result["created_at"])
except (ValueError, TypeError):
pass
# For updated_at
if "updated_at" in result and result.get("updated_at"):
try:
metadata["updated_at"] = datetime.fromisoformat(result["updated_at"])
except (ValueError, TypeError):
pass
# For categories
if "categories" in result and result.get("categories"):
metadata["categories"] = result["categories"]
# Create MemoryContent object
memory_content = MemoryContent(
content=content_text,
mime_type="text/plain", # Default to text/plain
metadata=metadata,
)
memory_contents.append(memory_content)
return MemoryQueryResult(results=memory_contents)
except Exception as e:
# Log the error but return empty results
logger.error(f"Error querying mem0 memory: {str(e)}")
return MemoryQueryResult(results=[])
async def update_context(
self,
model_context: ChatCompletionContext,
) -> UpdateContextResult:
"""Update the model context with relevant memories.
This method retrieves the conversation history from the model context,
uses the last message as a query to find relevant memories, and then
adds those memories to the context as a system message.
Args:
model_context: The model context to update.
Returns:
UpdateContextResult containing memories added to the context.
"""
# Get messages from context
messages = await model_context.get_messages()
if not messages:
return UpdateContextResult(memories=MemoryQueryResult(results=[]))
# Use the last message as query
last_message = messages[-1]
query_text = last_message.content if isinstance(last_message.content, str) else str(last_message)
# Query memory
query_results = await self.query(query_text, limit=self._limit)
# If we have results, add them to the context
if query_results.results:
# Format memories as numbered list
memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
memory_context = "\nRelevant memories:\n" + "\n".join(memory_strings)
# Add as system message
await model_context.add_message(SystemMessage(content=memory_context))
return UpdateContextResult(memories=query_results)
async def clear(self) -> None:
"""Clear all content from memory for the current user.
Raises:
Exception: If there's an error clearing mem0 memory.
"""
try:
self._client.delete_all(user_id=self._user_id) # type: ignore
except Exception as e:
logger.error(f"Error clearing mem0 memory: {str(e)}")
raise
async def close(self) -> None:
"""Clean up resources if needed.
This is a no-op for Mem0 clients as they don't require explicit cleanup.
"""
pass
@classmethod
def _from_config(cls, config: Mem0MemoryConfig) -> Self:
"""Create instance from configuration.
Args:
config: Configuration for Mem0Memory component.
Returns:
A new Mem0Memory instance.
"""
return cls(
user_id=config.user_id,
limit=config.limit,
is_cloud=config.is_cloud,
api_key=config.api_key,
config=config.config,
)
def _to_config(self) -> Mem0MemoryConfig:
"""Convert instance to configuration.
Returns:
Configuration representing this Mem0Memory instance.
"""
return Mem0MemoryConfig(
user_id=self._user_id,
limit=self._limit,
is_cloud=self._is_cloud,
api_key=self._api_key,
config=self._config,
)

View File

@@ -0,0 +1,9 @@
from ._redis_memory import (
RedisMemory,
RedisMemoryConfig,
)
__all__ = [
"RedisMemoryConfig",
"RedisMemory",
]

View File

@@ -0,0 +1,325 @@
import logging
from typing import Any, List, Literal
from agentdhal_core import CancellationToken, Component
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
try:
from redis import Redis
from redisvl.extensions.message_history import SemanticMessageHistory
from redisvl.utils.utils import deserialize, serialize
except ImportError as e:
raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e
class RedisMemoryConfig(BaseModel):
"""
Configuration for Redis-based vector memory.
This class defines the configuration options for using Redis as a vector memory store,
supporting semantic memory. It allows customization of the Redis connection, index settings,
similarity search parameters, and embedding model.
"""
redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
index_name: str = Field(default="chat_history", description="Name of the Redis collection")
prefix: str = Field(default="memory", description="prefix of the Redis collection")
distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
algorithm: Literal["flat", "hnsw"] = "flat"
top_k: int = Field(default=10, description="Number of results to return in queries")
datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
distance_threshold: float = Field(default=0.7, description="Minimum similarity score threshold")
model_name: str | None = Field(
default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
)
class RedisMemory(Memory, Component[RedisMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by RedisVL.
`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
retrieving content based on semantic similarity. It enhances agents with the ability to recall
contextually relevant information during conversations by leveraging vector embeddings to find
similar content.
This implementation requires the RedisVL extra to be installed. Install with:
.. code-block:: bash
pip install "agentdhal-ext[redisvl]"
Additionally, you will need access to a Redis instance.
To run a local instance of redis in docker:
.. code-block:: bash
docker run -d --name redis -p 6379:6379 redis:8
To download and run Redis locally:
.. code-block:: bash
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update > /dev/null 2>&1
sudo apt-get install redis-server > /dev/null 2>&1
redis-server --daemonize yes
Args:
config (RedisMemoryConfig | None): Configuration for the Redis memory.
If None, defaults to a RedisMemoryConfig with recommended settings.
Example:
.. code-block:: python
from logging import WARNING, getLogger
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core.memory import MemoryContent, MemoryMimeType
from agentdhal_extensions.memory.redis import RedisMemory, RedisMemoryConfig
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
logger = getLogger()
logger.setLevel(WARNING)
# Define tool to use
async def get_weather(city: str, units: str = "imperial") -> str:
if units == "imperial":
return f"The weather in {city} is 73 °F and Sunny."
elif units == "metric":
return f"The weather in {city} is 23 °C and Sunny."
else:
return f"Sorry, I don't know the weather in {city}."
async def main():
# Initailize Redis memory
redis_memory = RedisMemory(
config=RedisMemoryConfig(
redis_url="redis://localhost:6379",
index_name="chat_history",
prefix="memory",
)
)
# Add user preferences to memory
await redis_memory.add(
MemoryContent(
content="The weather should be in metric units",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "units"},
)
)
await redis_memory.add(
MemoryContent(
content="Meal recipe must be vegan",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "dietary"},
)
)
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
)
# Create assistant agent with ChromaDB memory
assistant_agent = AssistantAgent(
name="assistant_agent",
model_client=model_client,
tools=[get_weather],
memory=[redis_memory],
)
stream = assistant_agent.run_stream(task="What is the weather in New York?")
await Console(stream)
await model_client.close()
await redis_memory.close()
asyncio.run(main())
Output:
.. code-block:: text
---------- TextMessage (user) ----------
What is the weather in New York?
---------- MemoryQueryEvent (assistant_agent) ----------
[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata={'category': 'preferences', 'type': 'units'})]
---------- ToolCallRequestEvent (assistant_agent) ----------
[FunctionCall(id='call_tyCPvPPAV4SHWhtfpM6UMemr', arguments='{"city":"New York","units":"metric"}', name='get_weather')]
---------- ToolCallExecutionEvent (assistant_agent) ----------
[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_tyCPvPPAV4SHWhtfpM6UMemr', is_error=False)]
---------- ToolCallSummaryMessage (assistant_agent) ----------
The weather in New York is 23 °C and Sunny.
"""
component_config_schema = RedisMemoryConfig
component_provider_override = "agentdhal_extensions.memory.redis_memory.RedisMemory"
def __init__(self, config: RedisMemoryConfig | None = None) -> None:
"""Initialize RedisMemory."""
self.config = config or RedisMemoryConfig()
client = Redis.from_url(url=self.config.redis_url) # type: ignore[reportUknownMemberType]
self.message_history = SemanticMessageHistory(name=self.config.index_name, redis_client=client)
async def update_context(
self,
model_context: ChatCompletionContext,
) -> UpdateContextResult:
"""
Update the model context with relevant memory content.
This method retrieves memory content relevant to the last message in the context
and adds it as a system message. This implementation uses the last message in the context
as a query to find semantically similar memories and adds them all to the context as a
single system message.
Args:
model_context (ChatCompletionContext): The model context to update with relevant
memories.
Returns:
UpdateContextResult: Object containing the memories that were used to update the
context.
"""
messages = await model_context.get_messages()
if messages:
last_message = str(messages[-1].content)
else:
last_message = ""
query_results = await self.query(last_message)
stringified_messages = "\n\n".join([str(m.content) for m in query_results.results])
await model_context.add_message(SystemMessage(content=stringified_messages))
return UpdateContextResult(memories=query_results)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
"""Add a memory content object to Redis.
.. note::
To perform semantic search over stored memories RedisMemory creates a vector embedding
from the content field of a MemoryContent object. This content is assumed to be text,
JSON, or Markdown, and is passed to the vector embedding model specified in
RedisMemoryConfig.
Args:
content (MemoryContent): The memory content to store within Redis.
cancellation_token (CancellationToken): Token passed to cease operation. Not used.
"""
if content.mime_type == MemoryMimeType.TEXT:
memory_content = content.content
mime_type = "text/plain"
elif content.mime_type == MemoryMimeType.JSON:
memory_content = serialize(content.content)
mime_type = "application/json"
elif content.mime_type == MemoryMimeType.MARKDOWN:
memory_content = content.content
mime_type = "text/markdown"
else:
raise NotImplementedError(
f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
)
metadata = {"mime_type": mime_type}
metadata.update(content.metadata if content.metadata else {})
self.message_history.add_message(
{"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType]
)
async def query(
self,
query: str | MemoryContent,
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> MemoryQueryResult:
"""Query memory content based on semantic vector similarity.
.. note::
RedisMemory.query() supports additional keyword arguments to improve query performance.
top_k (int): The maximum number of relevant memories to include. Defaults to 10.
distance_threshold (float): The maximum distance in vector space to consider a memory
semantically similar when performining cosine similarity search. Defaults to 0.7.
Args:
query (str | MemoryContent): query to perform vector similarity search with. If a
string is passed, a vector embedding is created from it with the model specified
in the RedisMemoryConfig. If a MemoryContent object is passed, the content field
of this object is extracted and a vector embedding is created from it with the
model specified in the RedisMemoryConfig.
cancellation_token (CancellationToken): Token passed to cease operation. Not used.
Returns:
memoryQueryResult: Object containing memories relevant to the provided query.
"""
# get the query string, or raise an error for unsupported MemoryContent types
if isinstance(query, str):
prompt = query
elif isinstance(query, MemoryContent):
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
prompt = str(query.content)
elif query.mime_type == MemoryMimeType.JSON:
prompt = serialize(query.content)
else:
raise NotImplementedError(
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
)
else:
raise TypeError("'query' must be either a string or MemoryContent")
top_k = kwargs.pop("top_k", self.config.top_k)
distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)
results = self.message_history.get_relevant(
prompt=prompt, # type: ignore[reportArgumentType]
top_k=top_k,
distance_threshold=distance_threshold,
raw=False,
)
memories: List[MemoryContent] = []
for result in results:
metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType]
mime_type = MemoryMimeType(metadata.pop("mime_type"))
if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
memory_content = result["content"] # type: ignore[reportArgumentType]
elif mime_type == MemoryMimeType.JSON:
memory_content = deserialize(result["content"]) # type: ignore[reportArgumentType]
else:
raise NotImplementedError(
f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
)
memory = MemoryContent(
content=memory_content, # type: ignore[reportArgumentType]
mime_type=mime_type,
metadata=metadata,
)
memories.append(memory) # type: ignore[reportUknownMemberType]
return MemoryQueryResult(results=memories) # type: ignore[reportUknownMemberType]
async def clear(self) -> None:
"""Clear all entries from memory, preserving the RedisMemory resources."""
self.message_history.clear()
async def close(self) -> None:
"""Clears all entries from memory, and cleans up Redis client, index and resources."""
self.message_history.delete()

View File

View File

@@ -0,0 +1,16 @@
from ._worker_runtime import GrpcWorkerAgentRuntime
from ._worker_runtime_host import GrpcWorkerAgentRuntimeHost
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
try:
import grpc # type: ignore
except ImportError as e:
raise ImportError(
"To use the GRPC runtime the grpc extra must be installed. Run `pip install autogen-ext[grpc]`"
) from e
__all__ = [
"GrpcWorkerAgentRuntime",
"GrpcWorkerAgentRuntimeHost",
"GrpcWorkerAgentRuntimeHostServicer",
]

View File

@@ -0,0 +1,13 @@
GRPC_IMPORT_ERROR_STR = (
"Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
)
DATA_CONTENT_TYPE_ATTR = "datacontenttype"
DATA_SCHEMA_ATTR = "dataschema"
AGENT_SENDER_TYPE_ATTR = "agagentsendertype"
AGENT_SENDER_KEY_ATTR = "agagentsenderkey"
MESSAGE_KIND_ATTR = "agmsgkind"
MESSAGE_KIND_VALUE_PUBLISH = "publish"
MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"
MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"
MESSAGE_KIND_VALUE_RPC_ERROR = "error"

View File

@@ -0,0 +1,4 @@
from typing import Any, Sequence, Tuple
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
ChannelArgumentType = Sequence[Tuple[str, Any]]

View File

@@ -0,0 +1,45 @@
from agentdhal_core._subscription import Subscription
from agentdhal_core._type_prefix_subscription import TypePrefixSubscription
from agentdhal_core._type_subscription import TypeSubscription
from .protos import agent_worker_pb2
def subscription_to_proto(subscription: Subscription) -> agent_worker_pb2.Subscription:
match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(topic_type=topic_type, agent_type=agent_type),
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
),
)
case _:
raise ValueError("Unsupported subscription type.")
def subscription_from_proto(subscription: agent_worker_pb2.Subscription) -> Subscription:
oneofcase = subscription.WhichOneof("subscription")
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = subscription.typeSubscription
return TypeSubscription(
topic_type=type_subscription_msg.topic_type,
agent_type=type_subscription_msg.agent_type,
id=subscription.id,
)
case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = subscription.typePrefixSubscription
return TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
id=subscription.id,
)
case None:
raise ValueError("Invalid subscription message.")

View File

@@ -0,0 +1,856 @@
from __future__ import annotations
import asyncio
import inspect
import json
import logging
import signal
import uuid
import warnings
from asyncio import Future, Task
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
ClassVar,
DefaultDict,
Dict,
List,
Literal,
Mapping,
ParamSpec,
Sequence,
Set,
Tuple,
Type,
TypeVar,
cast,
)
from agentdhal_core import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentRuntime,
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
MessageSerializer,
Subscription,
TopicId,
)
from agentdhal_core._runtime_impl_helpers import SubscriptionManager, get_impl
from agentdhal_core._serialization import (
SerializationRegistry,
)
from agentdhal_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self
from agentdhal_extensions.runtimes.grpc._utils import subscription_to_proto
from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._type_helpers import ChannelArgumentType
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
try:
import grpc.aio
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
if TYPE_CHECKING:
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
logger = logging.getLogger("agentdhal_core")
event_logger = logging.getLogger("agentdhal_core.events")
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
type_func_alias = type
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
def __init__(self, queue: asyncio.Queue[Any]) -> None:
self._queue = queue
async def __anext__(self) -> Any:
return await self._queue.get()
def __aiter__(self) -> AsyncIterator[Any]:
return self
class HostConnection:
DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [
(
"grpc.service_config",
json.dumps(
{
"methodConfig": [
{
"name": [{}],
"retryPolicy": {
"maxAttempts": 3,
"initialBackoff": "0.01s",
"maxBackoff": "5s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
}
],
}
),
)
]
def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None: # type: ignore
self._channel = channel
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None
self._stub: AgentRpcAsyncStub = stub
self._client_id = str(uuid.uuid4())
@property
def stub(self) -> Any:
return self._stub
@property
def metadata(self) -> Sequence[Tuple[str, str]]:
return [("client-id", self._client_id)]
@classmethod
async def from_host_address(
cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG
) -> Self:
logger.info("Connecting to %s", host_address)
# Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config
merged_options = [
(k, v) for k, v in {**dict(HostConnection.DEFAULT_GRPC_CONFIG), **dict(extra_grpc_config)}.items()
]
channel = grpc.aio.insecure_channel(
host_address,
options=merged_options,
)
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
instance = cls(channel, stub)
instance._connection_task = await instance._connect(
stub, instance._send_queue, instance._recv_queue, instance._client_id
)
return instance
async def close(self) -> None:
if self._connection_task is None:
raise RuntimeError("Connection is not open.")
await self._channel.close()
await self._connection_task
@staticmethod
async def _connect(
stub: Any, # AgentRpcAsyncStub
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> Task[None]:
from grpc.aio import StreamStreamCall
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
)
await stream.wait_for_connection()
async def read_loop() -> None:
while True:
logger.info("Waiting for message from host")
message = cast(agent_worker_pb2.Message, await stream.read()) # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
return asyncio.create_task(read_loop())
async def send(self, message: agent_worker_pb2.Message) -> None:
logger.info(f"Send message to host: {message}")
await self._send_queue.put(message)
logger.info("Put message in send queue")
async def recv(self) -> agent_worker_pb2.Message:
logger.info("Getting message from queue")
return await self._recv_queue.get()
# TODO: Lots of types need to have protobuf equivalents:
# Core:
# - FunctionCall, CodeResult, possibly CodeBlock
# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/agentdhal_core/models/_types.py
#
# Agentchat:
# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/src/agentdhal_agentchat/messages.py to protobufs.
#
# Ext --
# CodeExecutor:
# - CommandLineCodeResult
class GrpcWorkerAgentRuntime(AgentRuntime):
"""An agent runtime for running remote or cross-language agents.
Agent messaging uses protobufs from `agent_worker.proto`_ and ``CloudEvent`` from `cloudevent.proto`_.
Cross-language agents will additionally require all agents use shared protobuf schemas for any message types that are sent between agents.
.. _agent_worker.proto: https://github.com/microsoft/autogen/blob/main/protos/agent_worker.proto
.. _cloudevent.proto: https://github.com/microsoft/autogen/blob/main/protos/cloudevent.proto
"""
# TODO: Needs to handle agent close() call
def __init__(
self,
host_address: str,
tracer_provider: TracerProvider | None = None,
extra_grpc_config: ChannelArgumentType | None = None,
payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
) -> None:
self._host_address = host_address
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._known_namespaces: set[str] = set()
self._read_task: None | Task[None] = None
self._running = False
self._pending_requests: Dict[str, Future[Any]] = {}
self._pending_requests_lock = asyncio.Lock()
self._next_request_id = 0
self._host_connection: HostConnection | None = None
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._serialization_registry = SerializationRegistry()
self._extra_grpc_config = extra_grpc_config or []
self._agent_instance_types: Dict[str, Type[Agent]] = {}
if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
self._payload_serialization_format = payload_serialization_format
async def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
raise ValueError("Runtime is already running.")
logger.info(f"Connecting to host: {self._host_address}")
self._host_connection = await HostConnection.from_host_address(
self._host_address, extra_grpc_config=self._extra_grpc_config
)
logger.info("Connection established")
if self._read_task is None:
self._read_task = asyncio.create_task(self._run_read_loop())
self._running = True
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _run_read_loop(self) -> None:
logger.info("Starting read loop")
assert self._host_connection is not None
# TODO: catch exceptions and reconnect
while self._running:
try:
message = await self._host_connection.recv()
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "request":
task = asyncio.create_task(self._process_request(message.request))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
task = asyncio.create_task(self._process_response(message.response))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("No message")
except Exception as e:
logger.error("Error in read loop", exc_info=e)
async def stop(self) -> None:
"""Stop the runtime immediately."""
if not self._running:
raise RuntimeError("Runtime is not running.")
self._running = False
# Wait for all background tasks to finish.
final_tasks_results = await asyncio.gather(*self._background_tasks, return_exceptions=True)
for task_result in final_tasks_results:
if isinstance(task_result, Exception):
logger.error("Error in background task", exc_info=task_result)
# Close the host connection.
if self._host_connection is not None:
try:
await self._host_connection.close()
except asyncio.CancelledError:
pass
# Cancel the read task.
if self._read_task is not None:
self._read_task.cancel()
try:
await self._read_task
except asyncio.CancelledError:
pass
async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None:
"""Stop the runtime when a signal is received."""
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def signal_handler() -> None:
logger.info("Received exit signal, shutting down gracefully...")
shutdown_event.set()
for sig in signals:
loop.add_signal_handler(sig, signal_handler)
# Wait for the signal to trigger the shutdown event.
await shutdown_event.wait()
# Stop the runtime.
await self.stop()
@property
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
async def _send_message(
self,
runtime_message: agent_worker_pb2.Message,
send_type: Literal["send", "publish"],
recipient: AgentId | TopicId,
telemetry_metadata: Mapping[str, str],
) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata):
await self._host_connection.send(runtime_message)
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
# TODO: use message_id
if not self._running:
raise ValueError("Runtime must be running when sending message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
data_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", recipient, parent=None, extraAttributes={"message_type": data_type}
):
# create a new future for the result
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
serialized_message = self._serialization_registry.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
telemetry_metadata = get_telemetry_grpc_metadata()
runtime_message = agent_worker_pb2.Message(
request=agent_worker_pb2.RpcRequest(
request_id=request_id,
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=data_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
)
)
# TODO: Find a way to handle timeouts/errors
task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
return await future
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> None:
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if message_id is None:
message_id = str(uuid.uuid4())
message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
):
serialized_message = self._serialization_registry.serialize(
message, type_name=message_type, data_content_type=self._payload_serialization_format
)
sender_id = sender or AgentId("unknown", "unknown")
attributes = {
_constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=self._payload_serialization_format
),
_constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.type
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.key
),
_constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH
),
}
# If sending JSON we fill text_data with the serialized message
# If sending Protobuf we fill proto_data with the serialized message
# TODO: add an encoding field for serializer
if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE:
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
# TODO: use text, or proto fields appropriately
binary_data=serialized_message,
)
)
else:
# We need to unpack the serialized proto back into an Any
# TODO: find a way to prevent the roundtrip serialization
any_proto = any_pb2.Any()
any_proto.ParseFromString(serialized_message)
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
proto_data=any_proto,
)
)
telemetry_metadata = get_telemetry_grpc_metadata()
task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
async def save_state(self) -> Mapping[str, Any]:
raise NotImplementedError("Saving state is not yet implemented.")
async def load_state(self, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Loading state is not yet implemented.")
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
raise NotImplementedError("Agent metadata is not yet implemented.")
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
raise NotImplementedError("Agent save_state is not yet implemented.")
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")
async def _get_new_request_id(self) -> str:
async with self._pending_requests_lock:
self._next_request_id += 1
return str(self._next_request_id)
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
recipient = AgentId(request.target.type, request.target.key)
sender: AgentId | None = None
if request.HasField("source"):
sender = AgentId(request.source.type, request.source.key)
logging.info(f"Processing request from {sender} to {recipient}")
else:
logging.info(f"Processing request from unknown source to {recipient}")
# Deserialize the message.
message = self._serialization_registry.deserialize(
request.payload.data,
type_name=request.payload.data_type,
data_content_type=request.payload.data_content_type,
)
# Get the receiving agent and prepare the message context.
rec_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
message_id=request.request_id,
)
# Call the receiving agent.
try:
with MessageHandlerContext.populate_context(rec_agent.id):
with self._trace_helper.trace_block(
"process",
rec_agent.id,
parent=request.metadata,
attributes={"request_id": request.request_id},
extraAttributes={"message_type": request.payload.data_type},
):
result = await rec_agent.on_message(message, ctx=message_context)
except BaseException as e:
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
request_id=request.request_id,
error=str(e),
metadata=get_telemetry_grpc_metadata(),
),
)
# Send the error response.
await self._host_connection.send(response_message)
return
# Serialize the result.
result_type = self._serialization_registry.type_name(result)
serialized_result = self._serialization_registry.serialize(
result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
# Create the response message.
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
request_id=request.request_id,
payload=agent_worker_pb2.Payload(
data_type=result_type,
data=serialized_result,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
metadata=get_telemetry_grpc_metadata(),
)
)
# Send the response.
await self._host_connection.send(response_message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
with self._trace_helper.trace_block(
"ack",
None,
parent=response.metadata,
attributes={"request_id": response.request_id},
extraAttributes={"message_type": response.payload.data_type},
):
# Deserialize the result.
result = self._serialization_registry.deserialize(
response.payload.data,
type_name=response.payload.data_type,
data_content_type=response.payload.data_content_type,
)
# Get the future and set the result.
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
future.set_exception(Exception(response.error))
else:
future.set_result(result)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
event_attributes = event.attributes
sender: AgentId | None = None
if (
_constants.AGENT_SENDER_TYPE_ATTR in event_attributes
and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
):
sender = AgentId(
event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
)
topic_id = TopicId(event.type, event.source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string
if message_content_type == JSON_DATA_CONTENT_TYPE:
message = self._serialization_registry.deserialize(
event.binary_data, type_name=message_type, data_content_type=message_content_type
)
elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
# TODO: find a way to prevent the roundtrip serialization
proto_binary_data = event.proto_data.SerializeToString()
message = self._serialization_registry.deserialize(
proto_binary_data, type_name=message_type, data_content_type=message_content_type
)
else:
raise ValueError(f"Unsupported message content type: {message_content_type}")
# TODO: dont read these values in the runtime
topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
is_marked_rpc_type = (
_constants.MESSAGE_KIND_ATTR in event_attributes
and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
)
if is_rpc and not is_marked_rpc_type:
warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)
# Send the message to each recipient.
responses: List[Awaitable[Any]] = []
for agent_id in recipients:
if agent_id == sender:
continue
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=is_rpc,
cancellation_token=CancellationToken(),
message_id=event.id,
)
agent = await self._get_agent(agent_id)
with MessageHandlerContext.populate_context(agent.id):
def stringify_attributes(
attributes: Mapping[str, cloudevent_pb2.CloudEvent.CloudEventAttributeValue],
) -> Mapping[str, str]:
result: Dict[str, str] = {}
for key, value in attributes.items():
item = None
match value.WhichOneof("attr"):
case "ce_boolean":
item = str(value.ce_boolean)
case "ce_integer":
item = str(value.ce_integer)
case "ce_string":
item = value.ce_string
case "ce_bytes":
item = str(value.ce_bytes)
case "ce_uri":
item = value.ce_uri
case "ce_uri_ref":
item = value.ce_uri_ref
case "ce_timestamp":
item = str(value.ce_timestamp)
case _:
raise ValueError("Unknown attribute kind")
result[key] = item
return result
async def send_message(agent: Agent, message_context: MessageContext) -> Any:
with self._trace_helper.trace_block(
"process",
agent.id,
parent=stringify_attributes(event.attributes),
extraAttributes={"message_type": message_type},
):
await agent.on_message(message, ctx=message_context)
future = send_message(agent, message_context)
responses.append(future)
# Wait for all responses.
try:
await asyncio.gather(*responses)
except BaseException as e:
logger.error("Error handling event", exc_info=e)
async def _register_agent_type(self, agent_type: str) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type)
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
async def register_factory(
self,
type: str | AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
*,
expected_class: type[T] | None = None,
) -> AgentType:
if isinstance(type, str):
type = AgentType(type)
if type.type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
agent_instance = await maybe_agent_instance
else:
agent_instance = maybe_agent_instance
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance
self._agent_factories[type.type] = factory_wrapper
# Send the registration request message to the host.
await self._register_agent_type(type.type)
return type
async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> Agent:
raise RuntimeError(
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
)
if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")
if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
await self._register_agent_type(agent_id.type)
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
with AgentInstantiationContext.populate_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
warnings.warn(
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
agent = cast(T, await agent)
return agent
async def _get_agent(self, agent_id: AgentId) -> Agent:
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if agent_id.type not in self._agent_factories:
raise ValueError(f"Agent with name {agent_id.type} not found.")
agent_factory = self._agent_factories[agent_id.type]
agent = await self._invoke_agent_factory(agent_factory, agent_id)
self._instantiated_agents[agent_id] = agent
return agent
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
if id.type not in self._agent_factories:
raise LookupError(f"Agent with name {id.type} not found.")
# TODO: check if remote
agent_instance = await self._get_agent(id)
if not isinstance(agent_instance, type):
raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
return agent_instance
async def add_subscription(self, subscription: Subscription) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription))
_response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription(
message, metadata=self._host_connection.metadata
)
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)
async def remove_subscription(self, id: str) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.RemoveSubscriptionRequest(id=id)
_response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription(
message, metadata=self._host_connection.metadata
)
await self._subscription_manager.remove_subscription(id)
async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
) -> AgentId:
return await get_impl(
id_or_type=id_or_type,
key=key,
lazy=lazy,
instance_getter=self._get_agent,
)
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)

View File

@@ -0,0 +1,73 @@
import asyncio
import logging
import signal
from typing import Optional, Sequence
from ._constants import GRPC_IMPORT_ERROR_STR
from ._type_helpers import ChannelArgumentType
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
from .protos import agent_worker_pb2_grpc
logger = logging.getLogger("agentdhal_core")
class GrpcWorkerAgentRuntimeHost:
def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None:
self._server = grpc.aio.server(options=extra_grpc_config)
self._servicer = GrpcWorkerAgentRuntimeHostServicer()
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
self._server.add_insecure_port(address)
self._address = address
self._serve_task: asyncio.Task[None] | None = None
async def _serve(self) -> None:
await self._server.start()
logger.info(f"Server started at {self._address}.")
await self._server.wait_for_termination()
def start(self) -> None:
"""Start the server in a background task."""
if self._serve_task is not None:
raise RuntimeError("Host runtime is already started.")
self._serve_task = asyncio.create_task(self._serve())
async def stop(self, grace: int = 5) -> None:
"""Stop the server."""
if self._serve_task is None:
raise RuntimeError("Host runtime is not started.")
await self._server.stop(grace=grace)
self._serve_task.cancel()
try:
await self._serve_task
except asyncio.CancelledError:
pass
logger.info("Server stopped.")
self._serve_task = None
async def stop_when_signal(
self, grace: int = 5, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)
) -> None:
"""Stop the server when a signal is received."""
if self._serve_task is None:
raise RuntimeError("Host runtime is not started.")
# Set up signal handling for graceful shutdown.
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def signal_handler() -> None:
logger.info("Received exit signal, shutting down gracefully...")
shutdown_event.set()
for sig in signals:
loop.add_signal_handler(sig, signal_handler)
# Wait for the signal to trigger the shutdown event.
await shutdown_event.wait()
# Shutdown the server.
await self.stop(grace=grace)

View File

@@ -0,0 +1,364 @@
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from asyncio import Future, Task
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar
from agentdhal_core import TopicId
from agentdhal_core._agent_id import AgentId
from agentdhal_core._runtime_impl_helpers import SubscriptionManager
from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto, subscription_to_proto
try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
logger = logging.getLogger("agentdhal_core")
event_logger = logging.getLogger("agentdhal_core.events")
ClientConnectionId = str
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
if metadata is None:
return {}
return {key: value for key, value in metadata}
async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str: # type: ignore
# The type hint on context.invocation_metadata() is incorrect.
metadata = metadata_to_dict(context.invocation_metadata()) # type: ignore
if (client_id := metadata.get("client-id")) is None:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.")
return client_id # type: ignore
SendT = TypeVar("SendT")
ReceiveT = TypeVar("ReceiveT")
class ChannelConnection(ABC, Generic[SendT, ReceiveT]):
def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None:
self._request_iterator = request_iterator
self._client_id = client_id
self._send_queue: asyncio.Queue[SendT] = asyncio.Queue()
self._receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
await self._handle_message(message)
def __aiter__(self) -> AsyncIterator[SendT]:
return self
async def __anext__(self) -> SendT:
try:
return await self._send_queue.get()
except StopAsyncIteration:
await self._receiving_task
raise
except Exception as e:
logger.error(f"Failed to get message from send queue: {e}", exc_info=True)
await self._receiving_task
raise
@abstractmethod
async def _handle_message(self, message: ReceiveT) -> None:
pass
async def send(self, message: SendT) -> None:
await self._send_queue.put(message)
class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]):
def __init__(
self,
request_iterator: AsyncIterator[ReceiveT],
client_id: str,
handle_callback: Callable[[ReceiveT], Awaitable[None]],
) -> None:
self._handle_callback = handle_callback
super().__init__(request_iterator, client_id)
async def _handle_message(self, message: ReceiveT) -> None:
await self._handle_callback(message)
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._data_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
] = {}
self._control_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]
] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
async def OpenChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> AsyncIterator[agent_worker_pb2.Message]:
client_id = await get_client_id_or_abort(context)
async def handle_callback(message: agent_worker_pb2.Message) -> None:
await self._receive_message(client_id, message)
connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message](
request_iterator, client_id, handle_callback=handle_callback
)
self._data_connections[client_id] = connection
logger.info(f"Client {client_id} connected.")
try:
async for message in connection:
yield message
finally:
# Clean up the client connection.
del self._data_connections[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id)
async def OpenControlChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage],
context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage],
) -> AsyncIterator[agent_worker_pb2.ControlMessage]:
client_id = await get_client_id_or_abort(context)
async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None:
await self._receive_control_message(client_id, message)
connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage](
request_iterator, client_id, handle_callback=handle_callback
)
self._control_connections[client_id] = connection
logger.info(f"Client {client_id} connected.")
try:
async for message in connection:
yield message
finally:
# Clean up the client connection.
del self._control_connections[client_id]
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
try:
await self._subscription_manager.remove_subscription(sub_id)
# Catch and ignore if the subscription does not exist.
except ValueError:
continue
logger.info(f"Client {client_id} disconnected successfully")
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _receive_control_message(
self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage
) -> None:
logger.info(f"Received message from client {client_id}: {message}")
destination = message.destination
if destination.startswith("agentid="):
agent_id = AgentId.from_str(destination[len("agentid=") :])
target_client_id = self._agent_type_to_client_id.get(agent_id.type)
if target_client_id is None:
logger.error(f"Agent client id not found for agent type {agent_id.type}.")
return
elif destination.startswith("clientid="):
target_client_id = destination[len("clientid=") :]
else:
logger.error(f"Invalid destination {destination}")
return
target_send_queue = self._control_connections.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.send(message)
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type)
if target_client_id is None:
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
return
target_send_queue = self._data_connections.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.send(agent_worker_pb2.Message(request=request))
# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
self._background_tasks.add(send_response_task)
send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._data_connections.get(client_id)
if send_queue is None:
logger.error(f"Client {client_id} not found, failed to send response message.")
return
await send_queue.send(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
topic_id = TopicId(type=event.type, source=event.source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[ClientConnectionId] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
client_ids.add(client_id)
else:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event))
async def RegisterAgent( # type: ignore
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RegisterAgentTypeRequest, agent_worker_pb2.RegisterAgentTypeResponse
],
) -> agent_worker_pb2.RegisterAgentTypeResponse:
client_id = await get_client_id_or_abort(context)
async with self._agent_type_to_client_id_lock:
if request.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[request.type]
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
f"Agent type {request.type} already registered with client {existing_client_id}.",
)
else:
self._agent_type_to_client_id[request.type] = client_id
return agent_worker_pb2.RegisterAgentTypeResponse()
async def AddSubscription( # type: ignore
self,
request: agent_worker_pb2.AddSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.AddSubscriptionRequest, agent_worker_pb2.AddSubscriptionResponse
],
) -> agent_worker_pb2.AddSubscriptionResponse:
client_id = await get_client_id_or_abort(context)
subscription = subscription_from_proto(request.subscription)
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
return agent_worker_pb2.AddSubscriptionResponse()
async def RemoveSubscription( # type: ignore
self,
request: agent_worker_pb2.RemoveSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RemoveSubscriptionRequest, agent_worker_pb2.RemoveSubscriptionResponse
],
) -> agent_worker_pb2.RemoveSubscriptionResponse:
_client_id = await get_client_id_or_abort(context)
await self._subscription_manager.remove_subscription(request.id)
return agent_worker_pb2.RemoveSubscriptionResponse()
async def GetSubscriptions( # type: ignore
self,
request: agent_worker_pb2.GetSubscriptionsRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.GetSubscriptionsRequest, agent_worker_pb2.GetSubscriptionsResponse
],
) -> agent_worker_pb2.GetSubscriptionsResponse:
_client_id = await get_client_id_or_abort(context)
subscriptions = self._subscription_manager.subscriptions
return agent_worker_pb2.GetSubscriptionsResponse(
subscriptions=[subscription_to_proto(sub) for sub in subscriptions]
)
# async def GetState( # type: ignore
# self,
# request: agent_worker_pb2.AgentId,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
# ) -> agent_worker_pb2.GetStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")
# async def SaveState( # type: ignore
# self,
# request: agent_worker_pb2.AgentState,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
# ) -> agent_worker_pb2.SaveStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")

View File

@@ -0,0 +1,4 @@
"""
The :mod:`agentdhal_extensions.runtimes.grpc.protos` module provides Google Protobuf classes for agent-worker communication
"""

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: agent_worker.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'agent_worker.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import cloudevent_pb2 as cloudevent__pb2
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message\"4\n\x10SaveStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\"@\n\x11SaveStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"C\n\x10LoadStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\r\n\x05state\x18\x02 \x01(\t\"1\n\x11LoadStateResponse\x12\x12\n\x05\x65rror\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x87\x01\n\x0e\x43ontrolMessage\x12\x0e\n\x06rpc_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65stination\x18\x02 \x01(\t\x12\x17\n\nrespond_to\x18\x03 \x01(\tH\x00\x88\x01\x01\x12(\n\nrpcMessage\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyB\r\n\x0b_respond_to2\xe7\x03\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12H\n\x12OpenControlChannel\x12\x16.agents.ControlMessage\x1a\x16.agents.ControlMessage(\x01\x30\x01\x12T\n\rRegisterAgent\x12 .agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_worker_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\032Microsoft.AutoGen.Protobuf'
_globals['_RPCREQUEST_METADATAENTRY']._loaded_options = None
_globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001'
_globals['_RPCRESPONSE_METADATAENTRY']._loaded_options = None
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001'
_globals['_AGENTID']._serialized_start=75
_globals['_AGENTID']._serialized_end=111
_globals['_PAYLOAD']._serialized_start=113
_globals['_PAYLOAD']._serialized_end=182
_globals['_RPCREQUEST']._serialized_start=185
_globals['_RPCREQUEST']._serialized_end=450
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=392
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=439
_globals['_RPCRESPONSE']._serialized_start=453
_globals['_RPCRESPONSE']._serialized_end=637
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=392
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=439
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=639
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=679
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=681
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=708
_globals['_TYPESUBSCRIPTION']._serialized_start=710
_globals['_TYPESUBSCRIPTION']._serialized_end=768
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=770
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=841
_globals['_SUBSCRIPTION']._serialized_start=844
_globals['_SUBSCRIPTION']._serialized_end=1006
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1008
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1076
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1078
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1103
_globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_start=1105
_globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_end=1144
_globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_start=1146
_globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_end=1174
_globals['_GETSUBSCRIPTIONSREQUEST']._serialized_start=1176
_globals['_GETSUBSCRIPTIONSREQUEST']._serialized_end=1201
_globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_start=1203
_globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_end=1274
_globals['_MESSAGE']._serialized_start=1277
_globals['_MESSAGE']._serialized_end=1430
_globals['_SAVESTATEREQUEST']._serialized_start=1432
_globals['_SAVESTATEREQUEST']._serialized_end=1484
_globals['_SAVESTATERESPONSE']._serialized_start=1486
_globals['_SAVESTATERESPONSE']._serialized_end=1550
_globals['_LOADSTATEREQUEST']._serialized_start=1552
_globals['_LOADSTATEREQUEST']._serialized_end=1619
_globals['_LOADSTATERESPONSE']._serialized_start=1621
_globals['_LOADSTATERESPONSE']._serialized_end=1670
_globals['_CONTROLMESSAGE']._serialized_start=1673
_globals['_CONTROLMESSAGE']._serialized_end=1808
_globals['_AGENTRPC']._serialized_start=1811
_globals['_AGENTRPC']._serialized_end=2298
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,457 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
from . import cloudevent_pb2
import collections.abc
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class AgentId(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
type: builtins.str
key: builtins.str
def __init__(
self,
*,
type: builtins.str = ...,
key: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "type", b"type"]) -> None: ...
global___AgentId = AgentId
@typing.final
class Payload(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
data_type: builtins.str
data_content_type: builtins.str
data: builtins.bytes
def __init__(
self,
*,
data_type: builtins.str = ...,
data_content_type: builtins.str = ...,
data: builtins.bytes = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ...
global___Payload = Payload
@typing.final
class RpcRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class MetadataEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
value: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
REQUEST_ID_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
TARGET_FIELD_NUMBER: builtins.int
METHOD_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
method: builtins.str
@property
def source(self) -> global___AgentId: ...
@property
def target(self) -> global___AgentId: ...
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
*,
request_id: builtins.str = ...,
source: global___AgentId | None = ...,
target: global___AgentId | None = ...,
method: builtins.str = ...,
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___RpcRequest = RpcRequest
@typing.final
class RpcResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class MetadataEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
value: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
REQUEST_ID_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
error: builtins.str
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
*,
request_id: builtins.str = ...,
payload: global___Payload | None = ...,
error: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ...
global___RpcResponse = RpcResponse
@typing.final
class RegisterAgentTypeRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_FIELD_NUMBER: builtins.int
type: builtins.str
def __init__(
self,
*,
type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["type", b"type"]) -> None: ...
global___RegisterAgentTypeRequest = RegisterAgentTypeRequest
@typing.final
class RegisterAgentTypeResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
@typing.final
class TypeSubscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPIC_TYPE_FIELD_NUMBER: builtins.int
AGENT_TYPE_FIELD_NUMBER: builtins.int
topic_type: builtins.str
agent_type: builtins.str
def __init__(
self,
*,
topic_type: builtins.str = ...,
agent_type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ...
global___TypeSubscription = TypeSubscription
@typing.final
class TypePrefixSubscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPIC_TYPE_PREFIX_FIELD_NUMBER: builtins.int
AGENT_TYPE_FIELD_NUMBER: builtins.int
topic_type_prefix: builtins.str
agent_type: builtins.str
def __init__(
self,
*,
topic_type_prefix: builtins.str = ...,
agent_type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]) -> None: ...
global___TypePrefixSubscription = TypePrefixSubscription
@typing.final
class Subscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int
TYPEPREFIXSUBSCRIPTION_FIELD_NUMBER: builtins.int
id: builtins.str
@property
def typeSubscription(self) -> global___TypeSubscription: ...
@property
def typePrefixSubscription(self) -> global___TypePrefixSubscription: ...
def __init__(
self,
*,
id: builtins.str = ...,
typeSubscription: global___TypeSubscription | None = ...,
typePrefixSubscription: global___TypePrefixSubscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["id", b"id", "subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
global___Subscription = Subscription
@typing.final
class AddSubscriptionRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUBSCRIPTION_FIELD_NUMBER: builtins.int
@property
def subscription(self) -> global___Subscription: ...
def __init__(
self,
*,
subscription: global___Subscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ...
global___AddSubscriptionRequest = AddSubscriptionRequest
@typing.final
class AddSubscriptionResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___AddSubscriptionResponse = AddSubscriptionResponse
@typing.final
class RemoveSubscriptionRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
id: builtins.str
def __init__(
self,
*,
id: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["id", b"id"]) -> None: ...
global___RemoveSubscriptionRequest = RemoveSubscriptionRequest
@typing.final
class RemoveSubscriptionResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___RemoveSubscriptionResponse = RemoveSubscriptionResponse
@typing.final
class GetSubscriptionsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___GetSubscriptionsRequest = GetSubscriptionsRequest
@typing.final
class GetSubscriptionsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
@property
def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Subscription]: ...
def __init__(
self,
*,
subscriptions: collections.abc.Iterable[global___Subscription] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["subscriptions", b"subscriptions"]) -> None: ...
global___GetSubscriptionsResponse = GetSubscriptionsResponse
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REQUEST_FIELD_NUMBER: builtins.int
RESPONSE_FIELD_NUMBER: builtins.int
CLOUDEVENT_FIELD_NUMBER: builtins.int
@property
def request(self) -> global___RpcRequest: ...
@property
def response(self) -> global___RpcResponse: ...
@property
def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ...
def __init__(
self,
*,
request: global___RpcRequest | None = ...,
response: global___RpcResponse | None = ...,
cloudEvent: cloudevent_pb2.CloudEvent | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ...
global___Message = Message
@typing.final
class SaveStateRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
AGENTID_FIELD_NUMBER: builtins.int
@property
def agentId(self) -> global___AgentId: ...
def __init__(
self,
*,
agentId: global___AgentId | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agentId", b"agentId"]) -> None: ...
global___SaveStateRequest = SaveStateRequest
@typing.final
class SaveStateResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STATE_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
state: builtins.str
error: builtins.str
def __init__(
self,
*,
state: builtins.str = ...,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "state", b"state"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___SaveStateResponse = SaveStateResponse
@typing.final
class LoadStateRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
AGENTID_FIELD_NUMBER: builtins.int
STATE_FIELD_NUMBER: builtins.int
state: builtins.str
@property
def agentId(self) -> global___AgentId: ...
def __init__(
self,
*,
agentId: global___AgentId | None = ...,
state: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agentId", b"agentId", "state", b"state"]) -> None: ...
global___LoadStateRequest = LoadStateRequest
@typing.final
class LoadStateResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ERROR_FIELD_NUMBER: builtins.int
error: builtins.str
def __init__(
self,
*,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___LoadStateResponse = LoadStateResponse
@typing.final
class ControlMessage(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RPC_ID_FIELD_NUMBER: builtins.int
DESTINATION_FIELD_NUMBER: builtins.int
RESPOND_TO_FIELD_NUMBER: builtins.int
RPCMESSAGE_FIELD_NUMBER: builtins.int
rpc_id: builtins.str
"""A response message should have the same id as the request message"""
destination: builtins.str
"""This is either:
agentid=AGENT_ID
clientid=CLIENT_ID
"""
respond_to: builtins.str
"""This is either:
agentid=AGENT_ID
clientid=CLIENT_ID
Empty string means the message is a response
"""
@property
def rpcMessage(self) -> google.protobuf.any_pb2.Any:
"""One of:
SaveStateRequest saveStateRequest = 2;
SaveStateResponse saveStateResponse = 3;
LoadStateRequest loadStateRequest = 4;
LoadStateResponse loadStateResponse = 5;
"""
def __init__(
self,
*,
rpc_id: builtins.str = ...,
destination: builtins.str = ...,
respond_to: builtins.str | None = ...,
rpcMessage: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "destination", b"destination", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage", "rpc_id", b"rpc_id"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_respond_to", b"_respond_to"]) -> typing.Literal["respond_to"] | None: ...
global___ControlMessage = ControlMessage

View File

@@ -0,0 +1,312 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import agent_worker_pb2 as agent__worker__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in agent_worker_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AgentRpcStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.OpenChannel = channel.stream_stream(
'/agents.AgentRpc/OpenChannel',
request_serializer=agent__worker__pb2.Message.SerializeToString,
response_deserializer=agent__worker__pb2.Message.FromString,
_registered_method=True)
self.OpenControlChannel = channel.stream_stream(
'/agents.AgentRpc/OpenControlChannel',
request_serializer=agent__worker__pb2.ControlMessage.SerializeToString,
response_deserializer=agent__worker__pb2.ControlMessage.FromString,
_registered_method=True)
self.RegisterAgent = channel.unary_unary(
'/agents.AgentRpc/RegisterAgent',
request_serializer=agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString,
response_deserializer=agent__worker__pb2.RegisterAgentTypeResponse.FromString,
_registered_method=True)
self.AddSubscription = channel.unary_unary(
'/agents.AgentRpc/AddSubscription',
request_serializer=agent__worker__pb2.AddSubscriptionRequest.SerializeToString,
response_deserializer=agent__worker__pb2.AddSubscriptionResponse.FromString,
_registered_method=True)
self.RemoveSubscription = channel.unary_unary(
'/agents.AgentRpc/RemoveSubscription',
request_serializer=agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString,
response_deserializer=agent__worker__pb2.RemoveSubscriptionResponse.FromString,
_registered_method=True)
self.GetSubscriptions = channel.unary_unary(
'/agents.AgentRpc/GetSubscriptions',
request_serializer=agent__worker__pb2.GetSubscriptionsRequest.SerializeToString,
response_deserializer=agent__worker__pb2.GetSubscriptionsResponse.FromString,
_registered_method=True)
class AgentRpcServicer(object):
"""Missing associated documentation comment in .proto file."""
def OpenChannel(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def OpenControlChannel(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def RegisterAgent(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def AddSubscription(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def RemoveSubscription(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetSubscriptions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AgentRpcServicer_to_server(servicer, server):
rpc_method_handlers = {
'OpenChannel': grpc.stream_stream_rpc_method_handler(
servicer.OpenChannel,
request_deserializer=agent__worker__pb2.Message.FromString,
response_serializer=agent__worker__pb2.Message.SerializeToString,
),
'OpenControlChannel': grpc.stream_stream_rpc_method_handler(
servicer.OpenControlChannel,
request_deserializer=agent__worker__pb2.ControlMessage.FromString,
response_serializer=agent__worker__pb2.ControlMessage.SerializeToString,
),
'RegisterAgent': grpc.unary_unary_rpc_method_handler(
servicer.RegisterAgent,
request_deserializer=agent__worker__pb2.RegisterAgentTypeRequest.FromString,
response_serializer=agent__worker__pb2.RegisterAgentTypeResponse.SerializeToString,
),
'AddSubscription': grpc.unary_unary_rpc_method_handler(
servicer.AddSubscription,
request_deserializer=agent__worker__pb2.AddSubscriptionRequest.FromString,
response_serializer=agent__worker__pb2.AddSubscriptionResponse.SerializeToString,
),
'RemoveSubscription': grpc.unary_unary_rpc_method_handler(
servicer.RemoveSubscription,
request_deserializer=agent__worker__pb2.RemoveSubscriptionRequest.FromString,
response_serializer=agent__worker__pb2.RemoveSubscriptionResponse.SerializeToString,
),
'GetSubscriptions': grpc.unary_unary_rpc_method_handler(
servicer.GetSubscriptions,
request_deserializer=agent__worker__pb2.GetSubscriptionsRequest.FromString,
response_serializer=agent__worker__pb2.GetSubscriptionsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'agents.AgentRpc', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('agents.AgentRpc', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AgentRpc(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def OpenChannel(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/agents.AgentRpc/OpenChannel',
agent__worker__pb2.Message.SerializeToString,
agent__worker__pb2.Message.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def OpenControlChannel(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/agents.AgentRpc/OpenControlChannel',
agent__worker__pb2.ControlMessage.SerializeToString,
agent__worker__pb2.ControlMessage.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def RegisterAgent(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/agents.AgentRpc/RegisterAgent',
agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString,
agent__worker__pb2.RegisterAgentTypeResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def AddSubscription(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/agents.AgentRpc/AddSubscription',
agent__worker__pb2.AddSubscriptionRequest.SerializeToString,
agent__worker__pb2.AddSubscriptionResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def RemoveSubscription(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/agents.AgentRpc/RemoveSubscription',
agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString,
agent__worker__pb2.RemoveSubscriptionResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetSubscriptions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/agents.AgentRpc/GetSubscriptions',
agent__worker__pb2.GetSubscriptionsRequest.SerializeToString,
agent__worker__pb2.GetSubscriptionsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,126 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import abc
from . import agent_worker_pb2
import collections.abc
import grpc
import grpc.aio
import typing
_T = typing.TypeVar("_T")
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg]
...
class AgentRpcStub:
def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
OpenChannel: grpc.StreamStreamMultiCallable[
agent_worker_pb2.Message,
agent_worker_pb2.Message,
]
OpenControlChannel: grpc.StreamStreamMultiCallable[
agent_worker_pb2.ControlMessage,
agent_worker_pb2.ControlMessage,
]
RegisterAgent: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.RegisterAgentTypeRequest,
agent_worker_pb2.RegisterAgentTypeResponse,
]
AddSubscription: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.AddSubscriptionRequest,
agent_worker_pb2.AddSubscriptionResponse,
]
RemoveSubscription: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.RemoveSubscriptionRequest,
agent_worker_pb2.RemoveSubscriptionResponse,
]
GetSubscriptions: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.GetSubscriptionsRequest,
agent_worker_pb2.GetSubscriptionsResponse,
]
class AgentRpcAsyncStub:
OpenChannel: grpc.aio.StreamStreamMultiCallable[
agent_worker_pb2.Message,
agent_worker_pb2.Message,
]
OpenControlChannel: grpc.aio.StreamStreamMultiCallable[
agent_worker_pb2.ControlMessage,
agent_worker_pb2.ControlMessage,
]
RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.RegisterAgentTypeRequest,
agent_worker_pb2.RegisterAgentTypeResponse,
]
AddSubscription: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.AddSubscriptionRequest,
agent_worker_pb2.AddSubscriptionResponse,
]
RemoveSubscription: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.RemoveSubscriptionRequest,
agent_worker_pb2.RemoveSubscriptionResponse,
]
GetSubscriptions: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.GetSubscriptionsRequest,
agent_worker_pb2.GetSubscriptionsResponse,
]
class AgentRpcServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def OpenChannel(
self,
request_iterator: _MaybeAsyncIterator[agent_worker_pb2.Message],
context: _ServicerContext,
) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ...
@abc.abstractmethod
def OpenControlChannel(
self,
request_iterator: _MaybeAsyncIterator[agent_worker_pb2.ControlMessage],
context: _ServicerContext,
) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.ControlMessage], collections.abc.AsyncIterator[agent_worker_pb2.ControlMessage]]: ...
@abc.abstractmethod
def RegisterAgent(
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.RegisterAgentTypeResponse, collections.abc.Awaitable[agent_worker_pb2.RegisterAgentTypeResponse]]: ...
@abc.abstractmethod
def AddSubscription(
self,
request: agent_worker_pb2.AddSubscriptionRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.AddSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.AddSubscriptionResponse]]: ...
@abc.abstractmethod
def RemoveSubscription(
self,
request: agent_worker_pb2.RemoveSubscriptionRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.RemoveSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.RemoveSubscriptionResponse]]: ...
@abc.abstractmethod
def GetSubscriptions(
self,
request: agent_worker_pb2.GetSubscriptionsRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.GetSubscriptionsResponse, collections.abc.Awaitable[agent_worker_pb2.GetSubscriptionsResponse]]: ...
def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: cloudevent.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'cloudevent.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63loudevent.proto\x12\x11io.cloudevents.v1\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xb0\x04\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x41\n\nattributes\x18\x05 \x03(\x0b\x32-.io.cloudevents.v1.CloudEvent.AttributesEntry\x12\x15\n\x0b\x62inary_data\x18\x06 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x07 \x01(\tH\x00\x12*\n\nproto_data\x18\x08 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1ai\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.io.cloudevents.v1.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cloudevent_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\033Microsoft.AutoGen.Contracts'
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._loaded_options = None
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_options = b'8\001'
_globals['_CLOUDEVENT']._serialized_start=100
_globals['_CLOUDEVENT']._serialized_end=660
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_start=333
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_end=438
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_start=441
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_end=652
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,125 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
*
CloudEvent Protobuf Format
- Required context attributes are explicitly represented.
- Optional and Extension context attributes are carried in a map structure.
- Data may be represented as binary, text, or protobuf messages.
"""
import builtins
import collections.abc
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import google.protobuf.timestamp_pb2
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class CloudEvent(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class AttributesEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def value(self) -> global___CloudEvent.CloudEventAttributeValue: ...
def __init__(
self,
*,
key: builtins.str = ...,
value: global___CloudEvent.CloudEventAttributeValue | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
@typing.final
class CloudEventAttributeValue(google.protobuf.message.Message):
"""*
The CloudEvent specification defines
seven attribute value types...
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CE_BOOLEAN_FIELD_NUMBER: builtins.int
CE_INTEGER_FIELD_NUMBER: builtins.int
CE_STRING_FIELD_NUMBER: builtins.int
CE_BYTES_FIELD_NUMBER: builtins.int
CE_URI_FIELD_NUMBER: builtins.int
CE_URI_REF_FIELD_NUMBER: builtins.int
CE_TIMESTAMP_FIELD_NUMBER: builtins.int
ce_boolean: builtins.bool
ce_integer: builtins.int
ce_string: builtins.str
ce_bytes: builtins.bytes
ce_uri: builtins.str
ce_uri_ref: builtins.str
@property
def ce_timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
def __init__(
self,
*,
ce_boolean: builtins.bool = ...,
ce_integer: builtins.int = ...,
ce_string: builtins.str = ...,
ce_bytes: builtins.bytes = ...,
ce_uri: builtins.str = ...,
ce_uri_ref: builtins.str = ...,
ce_timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["attr", b"attr"]) -> typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"] | None: ...
ID_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
SPEC_VERSION_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
ATTRIBUTES_FIELD_NUMBER: builtins.int
BINARY_DATA_FIELD_NUMBER: builtins.int
TEXT_DATA_FIELD_NUMBER: builtins.int
PROTO_DATA_FIELD_NUMBER: builtins.int
id: builtins.str
"""-- CloudEvent Context Attributes
Required Attributes
"""
source: builtins.str
"""URI-reference"""
spec_version: builtins.str
type: builtins.str
binary_data: builtins.bytes
text_data: builtins.str
@property
def attributes(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
"""Optional & Extension Attributes"""
@property
def proto_data(self) -> google.protobuf.any_pb2.Any: ...
def __init__(
self,
*,
id: builtins.str = ...,
source: builtins.str = ...,
spec_version: builtins.str = ...,
type: builtins.str = ...,
attributes: collections.abc.Mapping[builtins.str, global___CloudEvent.CloudEventAttributeValue] | None = ...,
binary_data: builtins.bytes = ...,
text_data: builtins.str = ...,
proto_data: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "binary_data", b"binary_data", "data", b"data", "id", b"id", "proto_data", b"proto_data", "source", b"source", "spec_version", b"spec_version", "text_data", b"text_data", "type", b"type"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
global___CloudEvent = CloudEvent

View File

@@ -0,0 +1,24 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in cloudevent_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)

View File

@@ -0,0 +1,23 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
*
CloudEvent Protobuf Format
- Required context attributes are explicitly represented.
- Optional and Extension context attributes are carried in a map structure.
- Data may be represented as binary, text, or protobuf messages.
"""
import abc
import collections.abc
import grpc
import grpc.aio
import typing
_T = typing.TypeVar("_T")
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg]
...

View File

@@ -0,0 +1,287 @@
import warnings
from typing import Awaitable, Callable, List, Optional, Union
from agentdhal_agentchat.agents import ApprovalFuncType, CodeExecutorAgent, UserProxyAgent
from agentdhal_agentchat.base import ChatAgent
from agentdhal_agentchat.teams import MagenticOneGroupChat
from agentdhal_core import CancellationToken
from agentdhal_core.code_executor import CodeExecutor
from agentdhal_core.models import ChatCompletionClient
from agentdhal_extensions.agents.file_surfer import FileSurfer
from agentdhal_extensions.agents.magentic_one import MagenticOneCoderAgent
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
from agentdhal_extensions.models.openai._openai_client import BaseOpenAIChatCompletionClient
# Docker imports for default code executor
try:
import docker
from docker.errors import DockerException
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
_docker_available = True
except ImportError:
docker = None # type: ignore
DockerException = Exception # type: ignore
DockerCommandLineCodeExecutor = None # type: ignore
_docker_available = False
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
def _is_docker_available() -> bool:
"""Check if Docker is available and running."""
if not _docker_available:
return False
try:
if docker is not None:
client = docker.from_env()
client.ping() # type: ignore
return True
except DockerException:
return False
return False
def _create_default_code_executor() -> CodeExecutor:
"""Create the default code executor, preferring Docker if available."""
if _is_docker_available() and DockerCommandLineCodeExecutor is not None:
try:
return DockerCommandLineCodeExecutor()
except Exception:
# Fallback to local if Docker fails to initialize
pass
# Issue warning and use local executor if Docker is not available
warnings.warn(
"Docker is not available or not running. Using LocalCommandLineCodeExecutor instead of the recommended DockerCommandLineCodeExecutor. "
"For security, it is recommended to install Docker and ensure it's running before using MagenticOne. "
"To install Docker, visit: https://docs.docker.com/get-docker/",
UserWarning,
stacklevel=3,
)
return LocalCommandLineCodeExecutor()
class MagenticOne(MagenticOneGroupChat):
"""
MagenticOne is a specialized group chat class that integrates various agents
such as FileSurfer, WebSurfer, Coder, and Executor to solve complex tasks.
To read more about the science behind Magentic-One, see the full blog post: `Magentic-One: A Generalist Multi-Agent System for Solving Complex Tasks <https://www.microsoft.com/en-us/research/articles/magentic-one-a-generalist-multi-agent-system-for-solving-complex-tasks>`_ and the references below.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[magentic-one]"
Args:
client (ChatCompletionClient): The client used for model interactions.
hil_mode (bool): Optional; If set to True, adds the UserProxyAgent to the list of agents.
input_func (InputFuncType | None): Optional; Function to use for user input in human-in-the-loop mode.
code_executor (CodeExecutor | None): Optional; Code executor to use. If None, will use Docker if available, otherwise local executor.
approval_func (ApprovalFuncType | None): Optional; Function to approve code execution before running. If None, code will execute without approval.
.. warning::
Using Magentic-One involves interacting with a digital world designed for humans, which carries inherent risks. To minimize these risks, consider the following precautions:
1. **Use Containers**: Run all tasks in docker containers to isolate the agents and prevent direct system attacks.
2. **Virtual Environment**: Use a virtual environment to run the agents and prevent them from accessing sensitive data.
3. **Monitor Logs**: Closely monitor logs during and after execution to detect and mitigate risky behavior.
4. **Human Oversight**: Run the examples with a human in the loop to supervise the agents and prevent unintended consequences.
5. **Limit Access**: Restrict the agents' access to the internet and other resources to prevent unauthorized actions.
6. **Safeguard Data**: Ensure that the agents do not have access to sensitive data or resources that could be compromised. Do not share sensitive information with the agents.
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences. Moreover, be cautious that Magentic-One may be susceptible to prompt injection attacks from webpages.
Architecture:
Magentic-One is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of domains. It represents a significant step towards developing agents that can complete tasks that people encounter in their work and personal lives.
Magentic-One work is based on a multi-agent architecture where a lead Orchestrator agent is responsible for high-level planning, directing other agents, and tracking task progress. The Orchestrator begins by creating a plan to tackle the task, gathering needed facts and educated guesses in a Task Ledger that is maintained. At each step of its plan, the Orchestrator creates a Progress Ledger where it self-reflects on task progress and checks whether the task is completed. If the task is not yet completed, it assigns one of Magentic-One's other agents a subtask to complete. After the assigned agent completes its subtask, the Orchestrator updates the Progress Ledger and continues in this way until the task is complete. If the Orchestrator finds that progress is not being made for enough steps, it can update the Task Ledger and create a new plan.
Overall, Magentic-One consists of the following agents:
- Orchestrator: The lead agent responsible for task decomposition and planning, directing other agents in executing subtasks, tracking overall progress, and taking corrective actions as needed.
- WebSurfer: An LLM-based agent proficient in commanding and managing the state of a Chromium-based web browser. It performs actions on the browser and reports on the new state of the web page.
- FileSurfer: An LLM-based agent that commands a markdown-based file preview application to read local files of most types. It can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure.
- Coder: An LLM-based agent specialized in writing code, analyzing information collected from other agents, or creating new artifacts.
- ComputerTerminal: Provides the team with access to a console shell where the Coder's programs can be executed, and where new programming libraries can be installed.
Together, Magentic-One's agents provide the Orchestrator with the tools and capabilities needed to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments.
Examples:
.. code-block:: python
# Autonomously complete a coding task:
import asyncio
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.teams.magentic_one import MagenticOne
from agentdhal_agentchat.ui import Console
async def example_usage():
client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client) # Uses DockerCommandLineCodeExecutor by default
task = "Write a Python script to fetch data from an API."
result = await Console(m1.run_stream(task=task))
print(result)
if __name__ == "__main__":
asyncio.run(example_usage())
.. code-block:: python
# Enable human-in-the-loop mode with explicit Docker executor and code approval
import asyncio
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.teams.magentic_one import MagenticOne
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse
def user_input_func(prompt: str) -> str:
\"\"\"Custom input function for user interaction.\"\"\"
return input(prompt)
def approval_func(request: ApprovalRequest) -> ApprovalResponse:
\"\"\"Simple approval function that requests user input.\"\"\"
print(f\"Code to execute:\\n{request.code}\")
user_input = input("Do you approve this code execution? (y/n): ").strip().lower()
if user_input == 'y':
return ApprovalResponse(approved=True, reason=\"User approved the code execution\")
else:
return ApprovalResponse(approved=False, reason=\"User denied the code execution\")
async def example_usage_hil():
client = OpenAIChatCompletionClient(model="gpt-4o")
# Explicitly specify Docker code executor for better security
async with DockerCommandLineCodeExecutor() as code_executor:
m1 = MagenticOne(
client=client,
hil_mode=True,
input_func=user_input_func,
code_executor=code_executor,
approval_func=approval_func
)
task = "Write a Python script to fetch data from an API."
result = await Console(m1.run_stream(task=task))
print(result)
if __name__ == "__main__":
asyncio.run(example_usage_hil())
.. code-block:: python
# Enable code execution approval without human-in-the-loop mode
import asyncio
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.teams.magentic_one import MagenticOne
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse
def approval_func(request: ApprovalRequest) -> ApprovalResponse:
\"\"\"Simple approval function that requests user input.\"\"\"
print(f\"Code to execute:\\n{request.code}\")
user_input = input("Do you approve this code execution? (y/n): ").strip().lower()
if user_input == 'y':
return ApprovalResponse(approved=True, reason=\"User approved the code execution\")
else:
return ApprovalResponse(approved=False, reason=\"User denied the code execution\")
async def example_usage_with_approval():
client = OpenAIChatCompletionClient(model="gpt-4o")
# Use approval_func for code approval only (hil_mode=False)
async with DockerCommandLineCodeExecutor() as code_executor:
m1 = MagenticOne(
client=client,
hil_mode=False, # No human-in-the-loop for general conversation
code_executor=code_executor,
approval_func=approval_func # But still ask for code execution approval
)
task = "Write a Python script to fetch data from an API."
result = await Console(m1.run_stream(task=task))
print(result)
if __name__ == "__main__":
asyncio.run(example_usage_with_approval())
References:
.. code-block:: bibtex
@article{fourney2024magentic,
title={Magentic-one: A generalist multi-agent system for solving complex tasks},
author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others},
journal={arXiv preprint arXiv:2411.04468},
year={2024},
url={https://arxiv.org/abs/2411.04468}
}
"""
def __init__(
self,
client: ChatCompletionClient,
hil_mode: bool = False,
input_func: InputFuncType | None = None,
code_executor: CodeExecutor | None = None,
approval_func: ApprovalFuncType | None = None,
):
self.client = client
self._validate_client_capabilities(client)
if code_executor is None:
warnings.warn(
"Instantiating MagenticOne without a code_executor is deprecated. Provide a code_executor to clear this warning (e.g., code_executor=DockerCommandLineCodeExecutor() ).",
DeprecationWarning,
stacklevel=2,
)
code_executor = _create_default_code_executor()
fs = FileSurfer("FileSurfer", model_client=client)
ws = MultimodalWebSurfer("WebSurfer", model_client=client)
coder = MagenticOneCoderAgent("Coder", model_client=client)
executor = CodeExecutorAgent("ComputerTerminal", code_executor=code_executor, approval_func=approval_func)
agents: List[ChatAgent] = [fs, ws, coder, executor]
if hil_mode:
user_proxy = UserProxyAgent("User", input_func=input_func)
agents.append(user_proxy)
super().__init__(agents, model_client=client)
def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
capabilities = client.model_info
required_capabilities = ["function_calling", "json_output"]
if not all(capabilities.get(cap) for cap in required_capabilities):
warnings.warn(
"Client capabilities for MagenticOne must include vision, " "function calling, and json output.",
stacklevel=2,
)
if not isinstance(client, BaseOpenAIChatCompletionClient):
warnings.warn(
"MagenticOne performs best with OpenAI GPT-4o model either " "through OpenAI or Azure OpenAI.",
stacklevel=2,
)

View File

@@ -0,0 +1,19 @@
from ._ai_search import (
AzureAISearchTool,
BaseAzureAISearchTool,
SearchQuery,
SearchResult,
SearchResults,
VectorizableTextQuery,
)
from ._config import AzureAISearchConfig
__all__ = [
"AzureAISearchTool",
"BaseAzureAISearchTool",
"SearchQuery",
"SearchResult",
"SearchResults",
"AzureAISearchConfig",
"VectorizableTextQuery",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,186 @@
"""Configuration for Azure AI Search tool.
This module provides configuration classes for the Azure AI Search tool, including
settings for authentication, search behavior, retry policies, and caching.
"""
import logging
from typing import (
List,
Literal,
Optional,
TypeVar,
Union,
)
from pydantic import BaseModel, Field, field_validator, model_validator
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
T = TypeVar("T", bound="AzureAISearchConfig")
logger = logging.getLogger(__name__)
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
DEFAULT_API_VERSION = "2023-10-01-preview"
class AzureAISearchConfig(BaseModel):
"""Configuration for Azure AI Search with validation.
This class defines the configuration parameters for Azure AI Search tools, including
authentication, search behavior, caching, and embedding settings.
.. note::
This class requires the ``azure`` extra for the ``autogen-ext`` package.
.. code-block:: bash
pip install -U "agentdhal-ext[azure]"
.. note::
**Prerequisites:**
1. An Azure AI Search service must be created in your Azure subscription.
2. The search index must be properly configured for your use case:
- For vector search: Index must have vector fields
- For semantic search: Index must have semantic configuration
- For hybrid search: Both vector fields and text fields must be configured
3. Required packages:
- Base functionality: ``azure-search-documents>=11.4.0``
- For Azure OpenAI embeddings: ``openai azure-identity``
- For OpenAI embeddings: ``openai``
Example Usage:
.. code-block:: python
from azure.core.credentials import AzureKeyCredential
from agentdhal_extensions.tools.azure import AzureAISearchConfig
# Basic configuration for full-text search
config = AzureAISearchConfig(
name="doc-search",
endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint
index_name="<your-index>", # Name of your search index
credential=AzureKeyCredential("<your-key>"), # Your Azure AI Search admin key
query_type="simple",
search_fields=["content", "title"], # Update with your searchable fields
top=5,
)
# Configuration for vector search with Azure OpenAI embeddings
vector_config = AzureAISearchConfig(
name="vector-search",
endpoint="https://your-search.search.windows.net",
index_name="<your-index>",
credential=AzureKeyCredential("<your-key>"),
query_type="vector",
vector_fields=["embedding"], # Update with your vector field name
embedding_provider="azure_openai",
embedding_model="text-embedding-ada-002",
openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint
openai_api_key="<your-openai-key>", # Your Azure OpenAI key
top=5,
)
# Configuration for hybrid search with semantic ranking
hybrid_config = AzureAISearchConfig(
name="hybrid-search",
endpoint="https://your-search.search.windows.net",
index_name="<your-index>",
credential=AzureKeyCredential("<your-key>"),
query_type="semantic",
semantic_config_name="<your-semantic-config>", # Name of your semantic configuration
search_fields=["content", "title"], # Update with your search fields
vector_fields=["embedding"], # Update with your vector field name
embedding_provider="openai",
embedding_model="text-embedding-ada-002",
openai_api_key="<your-openai-key>", # Your OpenAI API key
top=5,
)
"""
name: str = Field(description="The name of this tool instance")
description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
endpoint: str = Field(description="The full URL of your Azure AI Search service")
index_name: str = Field(description="Name of the search index to query")
credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
description="Azure credential for authentication (API key or token)"
)
api_version: str = Field(
default=DEFAULT_API_VERSION,
description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
)
query_type: QueryTypeLiteral = Field(
default="simple", description="Type of search to perform: simple, full, semantic, or vector"
)
search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
top: Optional[int] = Field(
default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
)
filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
semantic_config_name: Optional[str] = Field(
default=None, description="Semantic configuration name for enhanced results"
)
enable_caching: bool = Field(default=False, description="Whether to cache search results")
cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")
embedding_provider: Optional[str] = Field(
default=None, description="Name of embedding provider for client-side embeddings"
)
embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")
model_config = {"arbitrary_types_allowed": True}
@field_validator("endpoint")
def validate_endpoint(cls, v: str) -> str:
"""Validate that the endpoint is a valid URL."""
if not v.startswith(("http://", "https://")):
raise ValueError("endpoint must be a valid URL starting with http:// or https://")
return v
@field_validator("query_type")
def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral:
"""Normalize query type to standard values."""
if not v:
return "simple"
if isinstance(v, str) and v.lower() == "fulltext":
return "full"
return v
@field_validator("top")
def validate_top(cls, v: Optional[int]) -> Optional[int]:
"""Ensure top is a positive integer if provided."""
if v is not None and v <= 0:
raise ValueError("top must be a positive integer")
return v
@model_validator(mode="after")
def validate_interdependent_fields(self) -> "AzureAISearchConfig":
"""Validate interdependent fields after all fields have been parsed."""
if self.query_type == "semantic" and not self.semantic_config_name:
raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")
if self.query_type == "vector" and not self.vector_fields:
raise ValueError("vector_fields must be provided for vector search")
if (
self.embedding_provider
and self.embedding_provider.lower() == "azure_openai"
and self.embedding_model
and not self.openai_endpoint
):
raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")
return self

View File

@@ -0,0 +1,3 @@
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
__all__ = ["CodeExecutionInput", "CodeExecutionResult", "PythonCodeExecutionTool"]

View File

@@ -0,0 +1,96 @@
from agentdhal_core import CancellationToken, Component, ComponentModel
from agentdhal_core.code_executor import CodeBlock, CodeExecutor
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field, model_serializer
from typing_extensions import Self
class CodeExecutionInput(BaseModel):
code: str = Field(description="The contents of the Python code block that should be executed")
class CodeExecutionResult(BaseModel):
success: bool
output: str
@model_serializer
def ser_model(self) -> str:
return self.output
class PythonCodeExecutionToolConfig(BaseModel):
"""Configuration for PythonCodeExecutionTool"""
executor: ComponentModel
description: str = "Execute Python code blocks."
class PythonCodeExecutionTool(
BaseTool[CodeExecutionInput, CodeExecutionResult], Component[PythonCodeExecutionToolConfig]
):
"""A tool that executes Python code in a code executor and returns output.
Example executors:
* :class:`agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor`
* :class:`agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
* :class:`agentdhal_extensions.code_executors.azure.ACADynamicSessionsCodeExecutor`
Example usage:
.. code-block:: bash
pip install -U "agentdhal-agentchat" "agentdhal-ext[openai]" "yfinance" "matplotlib"
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
async def main() -> None:
tool = PythonCodeExecutionTool(LocalCommandLineCodeExecutor(work_dir="coding"))
agent = AssistantAgent(
"assistant", OpenAIChatCompletionClient(model="gpt-4o"), tools=[tool], reflect_on_tool_use=True
)
await Console(
agent.run_stream(
task="Create a plot of MSFT stock prices in 2024 and save it to a file. Use yfinance and matplotlib."
)
)
asyncio.run(main())
Args:
executor (CodeExecutor): The code executor that will be used to execute the code blocks.
"""
component_config_schema = PythonCodeExecutionToolConfig
component_provider_override = "agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool"
def __init__(self, executor: CodeExecutor):
super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
self._executor = executor
async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
code_blocks = [CodeBlock(code=args.code, language="python")]
result = await self._executor.execute_code_blocks(
code_blocks=code_blocks, cancellation_token=cancellation_token
)
return CodeExecutionResult(success=result.exit_code == 0, output=result.output)
def _to_config(self) -> PythonCodeExecutionToolConfig:
"""Convert current instance to config object"""
return PythonCodeExecutionToolConfig(executor=self._executor.dump_component())
@classmethod
def _from_config(cls, config: PythonCodeExecutionToolConfig) -> Self:
"""Create instance from config object"""
executor = CodeExecutor.load_component(config.executor)
return cls(executor=executor)

View File

@@ -0,0 +1,25 @@
from ._config import (
GlobalContextConfig,
GlobalDataConfig,
LocalContextConfig,
LocalDataConfig,
MapReduceConfig,
SearchConfig,
)
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
__all__ = [
"GlobalSearchTool",
"LocalSearchTool",
"GlobalDataConfig",
"LocalDataConfig",
"GlobalContextConfig",
"GlobalSearchToolArgs",
"GlobalSearchToolReturn",
"LocalContextConfig",
"LocalSearchToolArgs",
"LocalSearchToolReturn",
"MapReduceConfig",
"SearchConfig",
]

View File

@@ -0,0 +1,59 @@
from pydantic import BaseModel
class DataConfig(BaseModel):
input_dir: str
entity_table: str = "entities"
entity_embedding_table: str = "entities"
community_table: str = "communities"
community_level: int = 2
class GlobalDataConfig(DataConfig):
community_report_table: str = "community_reports"
class LocalDataConfig(DataConfig):
relationship_table: str = "relationships"
text_unit_table: str = "text_units"
class ContextConfig(BaseModel):
max_data_tokens: int = 8000
class GlobalContextConfig(ContextConfig):
use_community_summary: bool = False
shuffle_data: bool = True
include_community_rank: bool = True
min_community_rank: int = 0
community_rank_name: str = "rank"
include_community_weight: bool = True
community_weight_name: str = "occurrence weight"
normalize_community_weight: bool = True
max_data_tokens: int = 12000
class LocalContextConfig(ContextConfig):
text_unit_prop: float = 0.5
community_prop: float = 0.25
include_entity_rank: bool = True
rank_description: str = "number of relationships"
include_relationship_weight: bool = True
relationship_ranking_attribute: str = "rank"
class MapReduceConfig(BaseModel):
map_max_tokens: int = 1000
map_temperature: float = 0.0
reduce_max_tokens: int = 2000
reduce_temperature: float = 0.0
allow_general_knowledge: bool = False
json_mode: bool = False
response_type: str = "multiple paragraphs"
class SearchConfig(BaseModel):
max_tokens: int = 1500
temperature: float = 0.0
response_type: str = "multiple paragraphs"

View File

@@ -0,0 +1,233 @@
from pathlib import Path
import pandas as pd
import tiktoken
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol import ChatModel
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_entities,
read_indexer_reports,
)
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig
_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()
class GlobalSearchToolArgs(BaseModel):
query: str = Field(..., description="The user query to perform global search on.")
class GlobalSearchToolReturn(BaseModel):
answer: str
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
"""Enables running GraphRAG global search queries as an AutoGen tool.
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
The search combines graph-based document relationships with semantic embeddings to find relevant information.
.. note::
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"
Before using this tool, you must complete the GraphRAG setup and indexing process:
1. Follow the GraphRAG documentation to initialize your project and settings
2. Configure and tune your prompts for the specific use case
3. Run the indexing process to generate the required data files
4. Ensure you have the settings.yaml file from the setup process
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
for detailed instructions on completing these prerequisite steps.
Example usage with AssistantAgent:
.. code-block:: python
import asyncio
from pathlib import Path
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.tools.graphrag import GlobalSearchTool
from agentdhal_agentchat.agents import AssistantAgent
async def main():
# Initialize the OpenAI client
openai_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key="<api-key>",
)
# Set up global search tool
global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
# Create assistant agent with the global search tool
assistant_agent = AssistantAgent(
name="search_assistant",
tools=[global_tool],
model_client=openai_client,
system_message=(
"You are a tool selector AI assistant using the GraphRAG framework. "
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
),
)
# Run a sample query
query = "What is the overall sentiment of the community reports?"
await Console(assistant_agent.run_stream(task=query))
if __name__ == "__main__":
asyncio.run(main())
"""
def __init__(
self,
token_encoder: tiktoken.Encoding,
model: ChatModel,
data_config: DataConfig,
context_config: ContextConfig = _default_context_config,
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
):
super().__init__(
args_type=GlobalSearchToolArgs,
return_type=GlobalSearchToolReturn,
name="global_search_tool",
description="Perform a global search with given parameters using graphrag.",
)
# Use the provided model
self._model = model
# Load parquet files
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
report_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
)
# Fix: Use correct argument order and types for GraphRAG API
communities = read_indexer_communities(community_df, report_df)
reports = read_indexer_reports(report_df, community_df, data_config.community_level)
entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
context_builder = GlobalCommunityContext(
community_reports=reports,
communities=communities,
entities=entities,
token_encoder=token_encoder,
)
context_builder_params = {
"use_community_summary": context_config.use_community_summary,
"shuffle_data": context_config.shuffle_data,
"include_community_rank": context_config.include_community_rank,
"min_community_rank": context_config.min_community_rank,
"community_rank_name": context_config.community_rank_name,
"include_community_weight": context_config.include_community_weight,
"community_weight_name": context_config.community_weight_name,
"normalize_community_weight": context_config.normalize_community_weight,
"max_tokens": context_config.max_data_tokens,
"context_name": "Reports",
}
map_llm_params = {
"max_tokens": mapreduce_config.map_max_tokens,
"temperature": mapreduce_config.map_temperature,
"response_format": {"type": "json_object"},
}
reduce_llm_params = {
"max_tokens": mapreduce_config.reduce_max_tokens,
"temperature": mapreduce_config.reduce_temperature,
}
self._search_engine = GlobalSearch(
model=self._model,
context_builder=context_builder,
token_encoder=token_encoder,
max_data_tokens=context_config.max_data_tokens,
map_llm_params=map_llm_params,
reduce_llm_params=reduce_llm_params,
allow_general_knowledge=mapreduce_config.allow_general_knowledge,
json_mode=mapreduce_config.json_mode,
context_builder_params=context_builder_params,
concurrent_coroutines=32,
response_type=mapreduce_config.response_type,
)
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
search_result = await self._search_engine.search(args.query)
assert isinstance(search_result.response, str), "Expected response to be a string"
return GlobalSearchToolReturn(answer=search_result.response)
@classmethod
def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
"""Create a GlobalSearchTool instance from GraphRAG settings file.
Args:
root_dir: Path to the GraphRAG root directory
config_filepath: Path to the GraphRAG settings file (optional)
Returns:
An initialized GlobalSearchTool instance
"""
# Load GraphRAG config
if isinstance(root_dir, str):
root_dir = Path(root_dir)
if isinstance(config_filepath, str):
config_filepath = Path(config_filepath)
config = load_config(root_dir=root_dir, config_filepath=config_filepath)
# Get the language model configuration from the models section
chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
if chat_model_config is None:
raise ValueError("default_chat_model not found in config.models")
# Initialize token encoder based on the model being used
try:
token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
except KeyError:
# Fallback to cl100k_base if model is not recognized by tiktoken
token_encoder = tiktoken.get_encoding("cl100k_base")
# Create the LLM using ModelManager
model = ModelManager().get_or_create_chat_model(
name="global_search_model",
model_type=chat_model_config.type,
config=chat_model_config,
)
# Create data config from storage paths
data_config = DataConfig(
input_dir=str(config.output.base_dir),
)
return cls(
token_encoder=token_encoder,
model=model,
data_config=data_config,
context_config=_default_context_config,
mapreduce_config=_default_mapreduce_config,
)

View File

@@ -0,0 +1,245 @@
# mypy: disable-error-code="no-any-unimported,misc"
from pathlib import Path
import pandas as pd
import tiktoken
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_text_units,
)
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from ._config import LocalContextConfig, SearchConfig
from ._config import LocalDataConfig as DataConfig
_default_context_config = LocalContextConfig()
_default_search_config = SearchConfig()
class LocalSearchToolArgs(BaseModel):
query: str = Field(..., description="The user query to perform local search on.")
class LocalSearchToolReturn(BaseModel):
answer: str = Field(..., description="The answer to the user query.")
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
"""Enables running GraphRAG local search queries as an AutoGen tool.
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
The search combines local document context with semantic embeddings to find relevant information.
.. note::
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"
Before using this tool, you must complete the GraphRAG setup and indexing process:
1. Follow the GraphRAG documentation to initialize your project and settings
2. Configure and tune your prompts for the specific use case
3. Run the indexing process to generate the required data files
4. Ensure you have the settings.yaml file from the setup process
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
for detailed instructions on completing these prerequisite steps.
Example usage with AssistantAgent:
.. code-block:: python
import asyncio
from pathlib import Path
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.tools.graphrag import LocalSearchTool
from agentdhal_agentchat.agents import AssistantAgent
async def main():
# Initialize the OpenAI client
openai_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key="<api-key>",
)
# Set up local search tool
local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
# Create assistant agent with the local search tool
assistant_agent = AssistantAgent(
name="search_assistant",
tools=[local_tool],
model_client=openai_client,
system_message=(
"You are a tool selector AI assistant using the GraphRAG framework. "
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
"For specific, detailed information about particular entities or relationships, call the 'local_search' function."
),
)
# Run a sample query
query = "What does the station-master say about Dr. Becher?"
await Console(assistant_agent.run_stream(task=query))
if __name__ == "__main__":
asyncio.run(main())
Args:
token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
model: The chat model to use for search (GraphRAG ChatModel)
embedder: The text embedding model to use (GraphRAG EmbeddingModel)
data_config (DataConfig): Configuration for data source locations and settings
context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
"""
def __init__(
self,
token_encoder: tiktoken.Encoding,
model: ChatModel, # ChatModel from GraphRAG
embedder: EmbeddingModel, # EmbeddingModel from GraphRAG
data_config: DataConfig,
context_config: LocalContextConfig = _default_context_config,
search_config: SearchConfig = _default_search_config,
):
super().__init__(
args_type=LocalSearchToolArgs,
return_type=LocalSearchToolReturn,
name="local_search_tool",
description="Perform a local search with given parameters using graphrag.",
)
# Use the provided models
self._model = model
self._embedder = embedder
# Load parquet files
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
relationship_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.relationship_table}.parquet"
)
text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet") # type: ignore
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
# Read data using indexer adapters
entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
relationships = read_indexer_relationships(relationship_df)
text_units = read_indexer_text_units(text_unit_df)
# Set up vector store for entity embeddings
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb")
# Set up context builder
context_builder = LocalSearchMixedContext(
entities=entities,
entity_text_embeddings=description_embedding_store,
text_embedder=self._embedder,
text_units=text_units,
relationships=relationships,
token_encoder=token_encoder,
)
context_builder_params = {
"text_unit_prop": context_config.text_unit_prop,
"community_prop": context_config.community_prop,
"include_entity_rank": context_config.include_entity_rank,
"rank_description": context_config.rank_description,
"include_relationship_weight": context_config.include_relationship_weight,
"relationship_ranking_attribute": context_config.relationship_ranking_attribute,
"max_tokens": context_config.max_data_tokens,
}
llm_params = {
"max_tokens": search_config.max_tokens,
"temperature": search_config.temperature,
}
self._search_engine = LocalSearch(
model=self._model,
context_builder=context_builder,
token_encoder=token_encoder,
response_type=search_config.response_type,
context_builder_params=context_builder_params,
model_params=llm_params,
)
async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
search_result = await self._search_engine.search(args.query) # type: ignore[reportUnknownMemberType]
assert isinstance(search_result.response, str), "Expected response to be a string"
return LocalSearchToolReturn(answer=search_result.response)
@classmethod
def from_settings(cls, root_dir: Path, config_filepath: Path | None = None) -> "LocalSearchTool":
"""Create a LocalSearchTool instance from GraphRAG settings file.
Args:
root_dir: Path to the GraphRAG root directory
config_filepath: Path to the GraphRAG settings file (optional)
Returns:
An initialized LocalSearchTool instance
"""
# Load GraphRAG config
config = load_config(root_dir=root_dir, config_filepath=config_filepath)
# Get the language model configurations from the models section
chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)
if chat_model_config is None:
raise ValueError("default_chat_model not found in config.models")
if embedding_model_config is None:
raise ValueError("default_embedding_model not found in config.models")
# Initialize token encoder based on the model being used
try:
token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
except KeyError:
# Fallback to cl100k_base if model is not recognized by tiktoken
token_encoder = tiktoken.get_encoding("cl100k_base")
# Create the models using ModelManager
model = ModelManager().get_or_create_chat_model(
name="local_search_model",
model_type=chat_model_config.type,
config=chat_model_config,
)
embedder = ModelManager().get_or_create_embedding_model(
name="local_search_embedder",
model_type=embedding_model_config.type,
config=embedding_model_config,
)
# Create data config from storage paths
data_config = DataConfig(
input_dir=str(config.output.base_dir),
)
return cls(
token_encoder=token_encoder,
model=model,
embedder=embedder,
data_config=data_config,
context_config=_default_context_config,
search_config=_default_search_config,
)

View File

@@ -0,0 +1,3 @@
from ._http_tool import HttpTool
__all__ = ["HttpTool"]

View File

@@ -0,0 +1,244 @@
import re
from typing import Any, Literal, Optional, Type
import httpx
from agentdhal_core import CancellationToken, Component
from agentdhal_core.tools import BaseTool
from json_schema_to_pydantic import create_model
from pydantic import BaseModel, Field
from typing_extensions import Self
DEFAULT_TIMEOUT_CONFIG = 5.0
class HttpToolConfig(BaseModel):
name: str
"""
The name of the tool.
"""
description: Optional[str]
"""
A description of the tool.
"""
scheme: Literal["http", "https"] = "http"
"""
The scheme to use for the request.
"""
host: str
"""
The URL to send the request to.
"""
port: int
"""
The port to send the request to.
"""
path: str = Field(default="/")
"""
The path to send the request to. defaults to "/"
The path can accept parameters, e.g. "/{param1}/{param2}".
These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request.
"""
method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
"""
The HTTP method to use, will default to POST if not provided.
"""
headers: Optional[dict[str, Any]]
"""
A dictionary of headers to send with the request.
"""
json_schema: dict[str, Any]
"""
A JSON Schema object defining the expected parameters for the tool.
Path parameters MUST also be included in the json_schema. They must also MUST be set to string
"""
return_type: Optional[Literal["text", "json"]] = "text"
"""
The type of response to return from the tool.
"""
timeout: float = DEFAULT_TIMEOUT_CONFIG
"""
The timeout for the tool request in seconds.
"""
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
"""A wrapper for using an HTTP server as a tool.
Args:
name (str): The name of the tool.
description (str, optional): A description of the tool.
scheme (str): The scheme to use for the request. Must be either "http" or "https".
host (str): The host to send the request to.
port (int): The port to send the request to.
path (str, optional): The path to send the request to. Defaults to "/".
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
method (str, optional): The HTTP method to use, will default to POST if not provided.
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
Path parameters must also be included in the schema and must be strings.
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
Defaults to "text".
timeout (float, optional): The timeout for HTTP requests in seconds.
Defaults to 5.0.
.. note::
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "agentdhal-agentchat" "agentdhal-ext[http-tool]"
Example:
Simple use case::
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.messages import TextMessage
from agentdhal_core import CancellationToken
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.http import HttpTool
# Define a JSON schema for a base64 decode tool
base64_schema = {
"type": "object",
"properties": {
"value": {"type": "string", "description": "The base64 value to decode"},
},
"required": ["value"],
}
# Create an HTTP tool for the httpbin API
base64_tool = HttpTool(
name="base64_decode",
description="base64 decode a value",
scheme="https",
host="httpbin.org",
port=443,
path="/base64/{value}",
method="GET",
json_schema=base64_schema,
)
async def main():
# Create an assistant with the base64 tool
model = OpenAIChatCompletionClient(model="gpt-4")
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
# The assistant can now use the base64 tool to decode the string
response = await assistant.on_messages(
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
CancellationToken(),
)
print(response.chat_message)
asyncio.run(main())
"""
component_type = "tool"
component_provider_override = "agentdhal_extensions.tools.http.HttpTool"
component_config_schema = HttpToolConfig
def __init__(
self,
name: str,
host: str,
port: int,
json_schema: dict[str, Any],
headers: Optional[dict[str, Any]] = None,
description: str = "HTTP tool",
path: str = "/",
scheme: Literal["http", "https"] = "http",
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
return_type: Literal["text", "json"] = "text",
timeout: float = DEFAULT_TIMEOUT_CONFIG,
) -> None:
self.server_params = HttpToolConfig(
name=name,
description=description,
host=host,
port=port,
path=path,
scheme=scheme,
method=method,
headers=headers,
json_schema=json_schema,
return_type=return_type,
timeout=timeout,
)
# Use regex to find all path parameters, we will need those later to template the path
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
self._path_params = path_params
# Create the input model from the modified schema
input_model = create_model(json_schema)
# Use Any as return type since HTTP responses can vary
base_return_type: Type[Any] = object
super().__init__(input_model, base_return_type, name, description)
def _to_config(self) -> HttpToolConfig:
copied_config = self.server_params.model_copy()
return copied_config
@classmethod
def _from_config(cls, config: HttpToolConfig) -> Self:
copied_config = config.model_copy().model_dump()
return cls(**copied_config)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
"""Execute the HTTP tool with the given arguments.
Args:
args: The validated input arguments
cancellation_token: Token for cancelling the operation
Returns:
The response body from the HTTP call in JSON format
Raises:
Exception: If tool execution fails
"""
model_dump = args.model_dump()
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
# Remove path params from the model dump
for k in self._path_params:
model_dump.pop(k)
path = self.server_params.path.format(**path_params)
url = httpx.URL(
scheme=self.server_params.scheme,
host=self.server_params.host,
port=self.server_params.port,
path=path,
)
timeout_config = httpx.Timeout(timeout=self.server_params.timeout)
async with httpx.AsyncClient(timeout=timeout_config) as client:
match self.server_params.method:
case "GET":
response = await client.get(url, headers=self.server_params.headers, params=model_dump)
case "PUT":
response = await client.put(url, headers=self.server_params.headers, json=model_dump)
case "DELETE":
response = await client.delete(url, headers=self.server_params.headers, params=model_dump)
case "PATCH":
response = await client.patch(url, headers=self.server_params.headers, json=model_dump)
case _: # Default case POST
response = await client.post(url, headers=self.server_params.headers, json=model_dump)
match self.server_params.return_type:
case "text":
return response.text
case "json":
return response.json()
case _:
raise ValueError(f"Invalid return type: {self.server_params.return_type}")

View File

@@ -0,0 +1,3 @@
from ._langchain_adapter import LangChainToolAdapter
__all__ = ["LangChainToolAdapter"]

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
import asyncio
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, cast
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field, create_model
if TYPE_CHECKING:
from langchain_core.tools import BaseTool as LangChainTool
class LangChainToolAdapter(BaseTool[BaseModel, Any]):
"""Allows you to wrap a LangChain tool and make it available to AutoGen.
.. note::
This class requires the :code:`langchain` extra for the :code:`autogen-ext` package.
.. code-block:: bash
pip install -U "agentdhal-ext[langchain]"
Args:
langchain_tool (LangChainTool): A LangChain tool to wrap
Examples:
Use the `PythonAstREPLTool` from the `langchain_experimental` package to
create a tool that allows you to interact with a Pandas DataFrame.
.. code-block:: python
import asyncio
import pandas as pd
from langchain_experimental.tools.python.tool import PythonAstREPLTool
from agentdhal_extensions.tools.langchain import LangChainToolAdapter
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_agentchat.messages import TextMessage
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core import CancellationToken
async def main() -> None:
df = pd.read_csv("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv") # type: ignore
tool = LangChainToolAdapter(PythonAstREPLTool(locals={"df": df}))
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(
"assistant",
tools=[tool],
model_client=model_client,
system_message="Use the `df` variable to access the dataset.",
)
await Console(
agent.on_messages_stream(
[TextMessage(content="What's the average age of the passengers?", source="user")], CancellationToken()
)
)
asyncio.run(main())
This example demonstrates how to use the `SQLDatabaseToolkit` from the `langchain_community`
package to interact with an SQLite database.
It uses the :class:`~agentdhal_agentchat.team.RoundRobinGroupChat` to iterate the single agent over multiple steps.
If you want to one step at a time, you can just call `run_stream` method of the
:class:`~agentdhal_agentchat.agents.AssistantAgent` class directly.
.. code-block:: python
import asyncio
import sqlite3
import requests
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.conditions import TextMentionTermination
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.langchain import LangChainToolAdapter
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import ChatOpenAI
from sqlalchemy import Engine, create_engine
from sqlalchemy.pool import StaticPool
def get_engine_for_chinook_db() -> Engine:
url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
response = requests.get(url)
sql_script = response.text
connection = sqlite3.connect(":memory:", check_same_thread=False)
connection.executescript(sql_script)
return create_engine(
"sqlite://",
creator=lambda: connection,
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
async def main() -> None:
# Create the engine and database wrapper.
engine = get_engine_for_chinook_db()
db = SQLDatabase(engine)
# Create the toolkit.
llm = ChatOpenAI(temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
# Create the LangChain tool adapter for every tool in the toolkit.
tools = [LangChainToolAdapter(tool) for tool in toolkit.get_tools()]
# Create the chat completion client.
model_client = OpenAIChatCompletionClient(model="gpt-4o")
# Create the assistant agent.
agent = AssistantAgent(
"assistant",
model_client=model_client,
tools=tools, # type: ignore
model_client_stream=True,
system_message="Respond with 'TERMINATE' if the task is completed.",
)
# Create termination condition.
termination = TextMentionTermination("TERMINATE")
# Create a round-robin group chat to iterate the single agent over multiple steps.
chat = RoundRobinGroupChat([agent], termination_condition=termination)
# Run the chat.
await Console(chat.run_stream(task="Show some tables in the database"))
if __name__ == "__main__":
asyncio.run(main())
"""
def __init__(self, langchain_tool: LangChainTool):
self._langchain_tool: LangChainTool = langchain_tool
# Extract name and description
name = self._langchain_tool.name
description = self._langchain_tool.description or ""
# Determine the callable method
if hasattr(self._langchain_tool, "func") and callable(self._langchain_tool.func): # type: ignore
assert self._langchain_tool.func is not None # type: ignore
self._callable: Callable[..., Any] = self._langchain_tool.func # type: ignore
elif hasattr(self._langchain_tool, "_run") and callable(self._langchain_tool._run): # type: ignore
self._callable: Callable[..., Any] = self._langchain_tool._run # type: ignore
else:
raise AttributeError(
f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method."
)
# Determine args_type
if self._langchain_tool.args_schema: # pyright: ignore
args_type = self._langchain_tool.args_schema # pyright: ignore
else:
# Infer args_type from the callable's signature
sig = inspect.signature(cast(Callable[..., Any], self._callable)) # type: ignore
fields = {
k: (v.annotation, Field(...))
for k, v in sig.parameters.items()
if k != "self" and v.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
args_type = create_model(f"{name}Args", **fields) # type: ignore
# Note: type ignore is used due to a LangChain typing limitation
# Ensure args_type is a subclass of BaseModel
if not issubclass(args_type, BaseModel):
raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}")
# Assume return_type as Any if not specified
return_type: Type[Any] = object
super().__init__(args_type, return_type, name, description)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
# Prepare arguments
kwargs = args.model_dump()
# Determine if the callable is asynchronous
if inspect.iscoroutinefunction(self._callable):
return await self._callable(**kwargs)
else:
# Run in a thread to avoid blocking the event loop
return await asyncio.to_thread(self._call_sync, kwargs)
def _call_sync(self, kwargs: Dict[str, Any]) -> Any:
return self._callable(**kwargs)

Some files were not shown because too many files have changed in this diff Show More