first commit
This commit is contained in:
3
agent_dhal/agentdhal_extensions/__init__.py
Normal file
3
agent_dhal/agentdhal_extensions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import importlib.metadata
|
||||
|
||||
# Package version, looked up from the installed distribution metadata of "agentdhal_extensions".
__version__ = importlib.metadata.version("agentdhal_extensions")
|
||||
10
agent_dhal/agentdhal_extensions/agents/azure/__init__.py
Normal file
10
agent_dhal/agentdhal_extensions/agents/azure/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Guarded import: when the agent module cannot be imported, re-raise an
# ImportError that carries installation instructions and keeps the original
# error as its cause.
try:
    from ._azure_ai_agent import AzureAIAgent
except ImportError as import_error:
    install_hint = (
        "Dependencies for AzureAIAgent not found. "
        'Please install autogen-ext with the "azure" extra: '
        'pip install "agentdhal-ext[azure]"'
    )
    raise ImportError(install_hint) from import_error

__all__ = ["AzureAIAgent"]
|
||||
1096
agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py
Normal file
1096
agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
61
agent_dhal/agentdhal_extensions/agents/azure/_types.py
Normal file
61
agent_dhal/agentdhal_extensions/agents/azure/_types.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union
|
||||
|
||||
from agentdhal_core.tools import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from azure.ai.agents.models import (
|
||||
AzureAISearchToolDefinition,
|
||||
AzureFunctionToolDefinition,
|
||||
BingGroundingToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
FileSearchToolDefinition,
|
||||
MessageTextUrlCitationAnnotation,
|
||||
)
|
||||
|
||||
# An iterable of tool specifications. Each entry may be:
#   * one of the literal tool names listed below,
#   * one of the Azure tool-definition objects,
#   * a ``Tool`` instance, or
#   * a plain or awaitable-returning callable.
ListToolType = Iterable[
    Union[
        Literal[
            "file_search",
            "code_interpreter",
            "bing_grounding",
            "azure_ai_search",
            "azure_function",
        ],
        BingGroundingToolDefinition,
        CodeInterpreterToolDefinition,
        AzureAISearchToolDefinition,
        FileSearchToolDefinition,
        AzureFunctionToolDefinition,
        Tool,
        Callable[..., Any],
        Callable[..., Awaitable[Any]],
    ]
]
|
||||
|
||||
|
||||
class AzureAIAgentState(BaseModel):
    """
    Represents the state of an AzureAIAgent that can be saved and loaded.

    This state model keeps track of persistent information about an agent session
    including agent and thread identifiers, message history, and associated resources.

    Attributes:
        type (str): The type identifier for the state object, always "AzureAIAgentState"
        agent_id (Optional[str]): The ID of the Azure AI agent
        thread_id (Optional[str]): The ID of the conversation thread
        initial_message_ids (List[str]): List of message IDs from the initial state
        vector_store_id (Optional[str]): The ID of the associated vector store for file search
        uploaded_file_ids (List[str]): List of IDs for files uploaded to the agent
    """

    # Discriminator naming this state model.
    type: str = Field(default="AzureAIAgentState")

    # Identifiers; each is None until a value is assigned.
    agent_id: Optional[str] = None
    thread_id: Optional[str] = None
    vector_store_id: Optional[str] = None

    # Collections default to a fresh empty list per instance.
    initial_message_ids: List[str] = Field(default_factory=list)
    uploaded_file_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
def has_annotations(obj: Any) -> TypeGuard[list[MessageTextUrlCitationAnnotation]]:
|
||||
return obj is not None and isinstance(obj, list)
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._file_surfer import FileSurfer
|
||||
|
||||
# Public API of this package: only FileSurfer is exported.
__all__ = ["FileSurfer"]
|
||||
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from agentdhal_agentchat.utils import remove_images
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._markdown_file_browser import MarkdownFileBrowser
|
||||
|
||||
# from typing_extensions import Annotated
|
||||
from ._tool_definitions import (
|
||||
TOOL_FIND_NEXT,
|
||||
TOOL_FIND_ON_PAGE_CTRL_F,
|
||||
TOOL_OPEN_PATH,
|
||||
TOOL_PAGE_DOWN,
|
||||
TOOL_PAGE_UP,
|
||||
)
|
||||
|
||||
|
||||
class FileSurferConfig(BaseModel):
    """Configuration for FileSurfer agent"""

    # Name of the agent.
    name: str

    # Serialized model-client component, as produced by dump_component().
    model_client: ComponentModel

    # Optional description; may be left unset.
    description: str | None = None
|
||||
|
||||
|
||||
class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
    """An agent, used by MagenticOne, that acts as a local file previewer. FileSurfer can open and read a variety of common file types, and can navigate the local file hierarchy.

    Installation:

    .. code-block:: bash

        pip install "agentdhal-ext[file-surfer]"

    Args:
        name (str): The agent's name
        model_client (ChatCompletionClient): The model to use (must be tool-use enabled)
        description (str): The agent's description used by the team. Defaults to DEFAULT_DESCRIPTION
        base_path (str): The base path to use for the file browser. Defaults to the current working directory.

    """

    component_config_schema = FileSurferConfig
    component_provider_override = "agentdhal_extensions.agents.file_surfer.FileSurfer"

    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage(
            content="""
You are a helpful AI Assistant.
When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        description: str = DEFAULT_DESCRIPTION,
        # NOTE: os.getcwd() is evaluated once, when this function is defined,
        # not on every call.
        base_path: str = os.getcwd(),
    ) -> None:
        super().__init__(name, description)
        self._model_client = model_client
        self._chat_history: List[LLMMessage] = []
        self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """Message types this agent emits: only TextMessage."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Record the incoming messages, generate a reply and return it.

        Ordinary errors raised while generating the reply are reported back as
        a "File surfing error" text message (and recorded in the chat history)
        instead of being raised.
        """
        for chat_message in messages:
            self._chat_history.append(chat_message.to_model_message())

        try:
            _, content = await self._generate_reply(cancellation_token=cancellation_token)
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

        except Exception:
            # Catch Exception rather than BaseException so that task
            # cancellation (asyncio.CancelledError), KeyboardInterrupt and
            # SystemExit propagate instead of being turned into a chat message.
            content = f"File surfing error:\n\n{traceback.format_exc()}"
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Forget the accumulated chat history."""
        self._chat_history.clear()

    def _get_browser_state(self) -> Tuple[str, str]:
        """
        Get the current state of the browser, including the header and content.

        Returns:
            A ``(header, viewport_text)`` tuple. The header lists the path, the
            page title (when set) and the 1-based viewport position.
        """
        header = f"Path: {self._browser.path}\n"

        if self._browser.page_title is not None:
            header += f"Title: {self._browser.page_title}\n"

        current_page = self._browser.viewport_current_page
        total_pages = len(self._browser.viewport_pages)
        header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"

        return (header, self._browser.viewport)

    async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:
        """Ask the model what to do with the latest message and act on it.

        The last entry of the chat history is treated as the task. The model is
        called with the earlier history, a message describing where the file
        viewer currently is, and the task. A plain-text answer is returned
        as-is; tool calls are applied to the browser in order and the resulting
        browser state is returned.

        Returns:
            A ``(flag, text)`` tuple; the flag is always False in this
            implementation.
        """
        history = self._chat_history[0:-1]
        last_message = self._chat_history[-1]
        assert isinstance(last_message, UserMessage)

        task_content = last_message.content  # the last message from the sender is the task

        assert self._browser is not None

        context_message = UserMessage(
            source="user",
            content=f"Your file viewer is currently open to the file or directory '{self._browser.page_title}' with path '{self._browser.path}'.",
        )

        task_message = UserMessage(
            source="user",
            content=task_content,
        )

        create_result = await self._model_client.create(
            messages=self._get_compatible_context(history + [context_message, task_message]),
            tools=[
                TOOL_OPEN_PATH,
                TOOL_PAGE_DOWN,
                TOOL_PAGE_UP,
                TOOL_FIND_NEXT,
                TOOL_FIND_ON_PAGE_CTRL_F,
            ],
            cancellation_token=cancellation_token,
        )

        response = create_result.content

        if isinstance(response, str):
            # Answer directly.
            return False, response

        elif isinstance(response, list) and all(isinstance(item, FunctionCall) for item in response):
            function_calls = response
            for function_call in function_calls:
                tool_name = function_call.name

                try:
                    arguments = json.loads(function_call.arguments)
                except json.JSONDecodeError as e:
                    error_str = f"File surfer encountered an error decoding JSON arguments: {e}"
                    return False, error_str

                # Tool names that match none of the branches are ignored.
                if tool_name == "open_path":
                    path = arguments["path"]
                    self._browser.open_path(path)
                elif tool_name == "page_up":
                    self._browser.page_up()
                elif tool_name == "page_down":
                    self._browser.page_down()
                elif tool_name == "find_on_page_ctrl_f":
                    search_string = arguments["search_string"]
                    self._browser.find_on_page(search_string)
                elif tool_name == "find_next":
                    self._browser.find_next()

            # Report the browser state once, after every call has been applied.
            header, content = self._get_browser_state()
            final_response = header.strip() + "\n=======================\n" + content
            return False, final_response

        # Neither text nor a list of function calls.
        final_response = "TERMINATE"
        return False, final_response

    def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Ensure that the messages are compatible with the underlying client, by removing images if needed."""
        if self._model_client.model_info["vision"]:
            return messages
        else:
            return remove_images(messages)

    def _to_config(self) -> FileSurferConfig:
        """Dump this agent to a FileSurferConfig (name, model client, description)."""
        return FileSurferConfig(
            name=self.name,
            model_client=self._model_client.dump_component(),
            description=self.description,
        )

    @classmethod
    def _from_config(cls, config: FileSurferConfig) -> Self:
        """Build an agent from a FileSurferConfig, falling back to DEFAULT_DESCRIPTION when no description is set."""
        return cls(
            name=config.name,
            model_client=ChatCompletionClient.load_component(config.model_client),
            description=config.description or cls.DEFAULT_DESCRIPTION,
        )
|
||||
@@ -0,0 +1,317 @@
|
||||
# ruff: noqa: E722
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
# TODO: Fix unfollowed import
|
||||
from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException # type: ignore
|
||||
|
||||
|
||||
class MarkdownFileBrowser:
|
||||
"""
|
||||
(In preview) An extremely simple Markdown-powered file browser.
|
||||
"""
|
||||
|
||||
# TODO: Fix unfollowed import
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
viewport_size: Union[int, None] = 1024 * 8,
|
||||
base_path: str | None = os.getcwd(),
|
||||
cwd: str | None = None,
|
||||
):
|
||||
"""
|
||||
Instantiate a new MarkdownFileBrowser.
|
||||
|
||||
Arguments:
|
||||
viewport_size: Approximately how many *characters* fit in the viewport. Viewport dimensions are adjusted dynamically to avoid cutting off words (default: 8192).
|
||||
base_path: The base path to use for the file browser. Files outside this path cannot be accessed. Defaults to the current working directory.
|
||||
cwd: The browser's current working directory. Defaults to the system's current working directory.
|
||||
"""
|
||||
self.viewport_size = viewport_size # Applies only to the standard uri types
|
||||
self.history: List[Tuple[str, float]] = list()
|
||||
self.page_title: Optional[str] = None
|
||||
self.viewport_current_page = 0
|
||||
self.viewport_pages: List[Tuple[int, int]] = list()
|
||||
self._markdown_converter = MarkItDown()
|
||||
self._base_path = None if base_path is None else os.path.realpath(base_path)
|
||||
self._page_content: str = ""
|
||||
self._find_on_page_query: Union[str, None] = None
|
||||
self._find_on_page_last_result: Union[int, None] = None # Location of the last result
|
||||
|
||||
# Set the working directory
|
||||
if cwd is None:
|
||||
if self._validate_path(os.getcwd()):
|
||||
# Use the current working directory if it's in the base path
|
||||
cwd = os.path.realpath(os.getcwd())
|
||||
elif self._base_path is not None:
|
||||
# Otherwise, use the base path
|
||||
cwd = os.path.realpath(self._base_path)
|
||||
else:
|
||||
raise ValueError("No valid working directory (cwd) provided.")
|
||||
elif not self._validate_path(cwd):
|
||||
# A cwd was provided, but it is not valid
|
||||
raise ValueError(f"Working directory (cwd) '{cwd}' is not valid. It must be within the base path.")
|
||||
|
||||
# Populate the history with the current working directory
|
||||
self.set_path(os.path.realpath(cwd))
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""Return the path of the current page."""
|
||||
assert len(self.history) > 0
|
||||
return self.history[-1][0]
|
||||
|
||||
def _validate_path(self, path: str) -> bool:
|
||||
"""Validates the path to ensure it is within the base path.
|
||||
|
||||
Arguments:
|
||||
path: The path to validate.
|
||||
Returns:
|
||||
True if the path is valid, False otherwise.
|
||||
"""
|
||||
if self._base_path is None:
|
||||
return True
|
||||
|
||||
# Normalize the paths
|
||||
path = os.path.realpath(path)
|
||||
base = os.path.realpath(self._base_path)
|
||||
|
||||
# Check if the path is within the base path
|
||||
if os.path.commonpath([path, base]) != base:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def set_path(self, path: str) -> None:
|
||||
"""Sets the path of the current page.
|
||||
This will result in the file being opened for reading.
|
||||
|
||||
Arguments:
|
||||
path: An absolute or relative path of the file or directory to open."
|
||||
"""
|
||||
|
||||
# Handle relative paths
|
||||
path = os.path.expanduser(path)
|
||||
if not os.path.isabs(path):
|
||||
if os.path.isfile(self.path):
|
||||
path = os.path.abspath(os.path.join(os.path.dirname(self.path), path))
|
||||
elif os.path.isdir(self.path):
|
||||
path = os.path.abspath(os.path.join(self.path, path))
|
||||
# If neither a file or a directory, take it verbatim
|
||||
|
||||
# Validating the path wrt. the base path is done in _open_path
|
||||
path = os.path.realpath(path)
|
||||
|
||||
self.history.append((path, time.time()))
|
||||
self._open_path(path)
|
||||
self.viewport_current_page = 0
|
||||
self.find_on_page_query = None
|
||||
self.find_on_page_viewport = None
|
||||
|
||||
@property
|
||||
def viewport(self) -> str:
|
||||
"""Return the content of the current viewport."""
|
||||
bounds = self.viewport_pages[self.viewport_current_page]
|
||||
return self.page_content[bounds[0] : bounds[1]]
|
||||
|
||||
@property
|
||||
def page_content(self) -> str:
|
||||
"""Return the full contents of the current page."""
|
||||
return self._page_content
|
||||
|
||||
def _set_page_content(self, content: str, split_pages: bool = True) -> None:
|
||||
"""Sets the text content of the current page."""
|
||||
self._page_content = content
|
||||
|
||||
if split_pages:
|
||||
self._split_pages()
|
||||
else:
|
||||
self.viewport_pages = [(0, len(self._page_content))]
|
||||
|
||||
if self.viewport_current_page >= len(self.viewport_pages):
|
||||
self.viewport_current_page = len(self.viewport_pages) - 1
|
||||
|
||||
def page_down(self) -> None:
|
||||
"""Move the viewport down one page, if possible."""
|
||||
self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
|
||||
|
||||
def page_up(self) -> None:
|
||||
"""Move the viewport up one page, if possible."""
|
||||
self.viewport_current_page = max(self.viewport_current_page - 1, 0)
|
||||
|
||||
def find_on_page(self, query: str) -> Union[str, None]:
|
||||
"""Searches for the query from the current viewport forward, looping back to the start if necessary."""
|
||||
|
||||
# Did we get here via a previous find_on_page search with the same query?
|
||||
# If so, map to find_next
|
||||
if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
|
||||
return self.find_next()
|
||||
|
||||
# Ok it's a new search start from the current viewport
|
||||
self._find_on_page_query = query
|
||||
viewport_match = self._find_next_viewport(query, self.viewport_current_page)
|
||||
if viewport_match is None:
|
||||
self._find_on_page_last_result = None
|
||||
return None
|
||||
else:
|
||||
self.viewport_current_page = viewport_match
|
||||
self._find_on_page_last_result = viewport_match
|
||||
return self.viewport
|
||||
|
||||
def find_next(self) -> Union[str, None]:
|
||||
"""Scroll to the next viewport that matches the query"""
|
||||
|
||||
if self._find_on_page_query is None:
|
||||
return None
|
||||
|
||||
starting_viewport = self._find_on_page_last_result
|
||||
if starting_viewport is None:
|
||||
starting_viewport = 0
|
||||
else:
|
||||
starting_viewport += 1
|
||||
if starting_viewport >= len(self.viewport_pages):
|
||||
starting_viewport = 0
|
||||
|
||||
viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
|
||||
if viewport_match is None:
|
||||
self._find_on_page_last_result = None
|
||||
return None
|
||||
else:
|
||||
self.viewport_current_page = viewport_match
|
||||
self._find_on_page_last_result = viewport_match
|
||||
return self.viewport
|
||||
|
||||
def _find_next_viewport(self, query: Optional[str], starting_viewport: int) -> Union[int, None]:
|
||||
"""Search for matches between the starting viewport looping when reaching the end."""
|
||||
|
||||
if query is None:
|
||||
return None
|
||||
|
||||
# Normalize the query, and convert to a regular expression
|
||||
nquery = re.sub(r"\*", "__STAR__", query)
|
||||
nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
|
||||
nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word
|
||||
nquery = nquery.replace("__STAR__", ".*").lower()
|
||||
|
||||
if nquery.strip() == "":
|
||||
return None
|
||||
|
||||
idxs: List[int] = list()
|
||||
idxs.extend(range(starting_viewport, len(self.viewport_pages)))
|
||||
idxs.extend(range(0, starting_viewport))
|
||||
|
||||
for i in idxs:
|
||||
bounds = self.viewport_pages[i]
|
||||
content = self.page_content[bounds[0] : bounds[1]]
|
||||
|
||||
# TODO: Remove markdown links and images
|
||||
ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
|
||||
if re.search(nquery, ncontent):
|
||||
return i
|
||||
|
||||
return None
|
||||
|
||||
def open_path(self, path: str) -> str:
|
||||
"""Open a file or directory in the file surfer."""
|
||||
self.set_path(path)
|
||||
return self.viewport
|
||||
|
||||
def _split_pages(self) -> None:
|
||||
"""Split the page contents into pages that are approximately the viewport size. Small deviations are permitted to ensure words are not broken."""
|
||||
# Handle empty pages
|
||||
if len(self._page_content) == 0:
|
||||
self.viewport_pages = [(0, 0)]
|
||||
return
|
||||
|
||||
# Break the viewport into pages
|
||||
self.viewport_pages = []
|
||||
start_idx = 0
|
||||
while start_idx < len(self._page_content):
|
||||
end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
|
||||
# Adjust to end on a space
|
||||
while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
|
||||
end_idx += 1
|
||||
self.viewport_pages.append((start_idx, end_idx))
|
||||
start_idx = end_idx
|
||||
|
||||
def _open_path(
|
||||
self,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""Open a file for reading, converting it to Markdown in the process.
|
||||
|
||||
Arguments:
|
||||
path: The path of the file or directory to open.
|
||||
"""
|
||||
|
||||
if not self._validate_path(path):
|
||||
# Not robust to TOCTOU issues.
|
||||
# Mitigate by running with limited permissions, or use a sandbox.
|
||||
self.page_title = "FileNotFoundError"
|
||||
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
|
||||
else:
|
||||
try:
|
||||
if os.path.isdir(path): # TODO: Fix markdown_converter types
|
||||
res = self._markdown_converter.convert_stream( # type: ignore
|
||||
io.BytesIO(self._fetch_local_dir(path).encode("utf-8")), file_extension=".txt"
|
||||
)
|
||||
assert self._validate_path(path)
|
||||
self.page_title = res.title
|
||||
self._set_page_content(res.text_content, split_pages=False)
|
||||
else:
|
||||
res = self._markdown_converter.convert_local(path)
|
||||
assert self._validate_path(path)
|
||||
self.page_title = res.title
|
||||
self._set_page_content(res.text_content)
|
||||
except UnsupportedFormatException:
|
||||
self.page_title = "UnsupportedFormatException"
|
||||
self._set_page_content(f"# UnsupportedFormatException\n\nCannot preview '{path}' as Markdown.")
|
||||
except FileConversionException:
|
||||
self.page_title = "FileConversionException."
|
||||
self._set_page_content(f"# FileConversionException\n\nError converting '{path}' to Markdown.")
|
||||
except FileNotFoundError:
|
||||
self.page_title = "FileNotFoundError"
|
||||
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
|
||||
|
||||
def _fetch_local_dir(self, local_path: str) -> str:
|
||||
"""Render a local directory listing in HTML to assist with local file browsing via the "file://" protocol.
|
||||
Through rendered in HTML, later parts of the pipeline will convert the listing to Markdown.
|
||||
|
||||
Arguments:
|
||||
local_path: A path to the local directory whose contents are to be listed.
|
||||
|
||||
Returns:
|
||||
A directory listing, rendered in HTML.
|
||||
"""
|
||||
listing = f"""
|
||||
# Index of {local_path}
|
||||
|
||||
| Name | Size | Date Modified |
|
||||
| ---- | ---- | ------------- |
|
||||
| .. (parent directory) | | |
|
||||
"""
|
||||
for entry in os.listdir(local_path):
|
||||
size = ""
|
||||
full_path = os.path.join(local_path, entry)
|
||||
|
||||
mtime = ""
|
||||
try:
|
||||
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(full_path)).strftime("%Y-%m-%d %H:%M")
|
||||
except Exception as e:
|
||||
# Handles PermissionError, etc.
|
||||
mtime = f"N/A: {type(e).__name__}"
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
entry = entry + os.path.sep
|
||||
else:
|
||||
try:
|
||||
size = str(os.path.getsize(full_path))
|
||||
except Exception as e:
|
||||
# Handles PermissionError, etc.
|
||||
size = f"N/A: {type(e).__name__}"
|
||||
|
||||
listing += f"| {entry} | {size} | {mtime} |\n"
|
||||
return listing
|
||||
@@ -0,0 +1,50 @@
|
||||
from agentdhal_core.tools import ParametersSchema, ToolSchema
|
||||
|
||||
# Tool schemas describing the file-browser tools. The ``name`` values are the
# identifiers a model uses when calling a tool.

# Open a file or directory; takes one required string argument, "path".
TOOL_OPEN_PATH = ToolSchema(
    name="open_path",
    description="Open a local file or directory at a path in the text-based file browser and return current viewport content.",
    parameters=ParametersSchema(
        type="object",
        properties={
            "path": {
                "type": "string",
                "description": "The relative or absolute path of a local file to visit.",
            },
        },
        required=["path"],
    ),
)

# Scroll one page up; declares no parameters.
TOOL_PAGE_UP = ToolSchema(
    name="page_up",
    description="Scroll the viewport UP one page-length in the current file and return the new viewport content.",
)

# Scroll one page down; declares no parameters.
TOOL_PAGE_DOWN = ToolSchema(
    name="page_down",
    description="Scroll the viewport DOWN one page-length in the current file and return the new viewport content.",
)

# Search the page; takes one required string argument, "search_string".
TOOL_FIND_ON_PAGE_CTRL_F = ToolSchema(
    name="find_on_page_ctrl_f",
    description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F.",
    parameters=ParametersSchema(
        type="object",
        properties={
            "search_string": {
                "type": "string",
                "description": "The string to search for on the page. This search string supports wildcards like '*'",
            },
        },
        required=["search_string"],
    ),
)

# Jump to the next match of the current search; declares no parameters.
TOOL_FIND_NEXT = ToolSchema(
    name="find_next",
    description="Scroll the viewport to next occurrence of the search string.",
)
|
||||
@@ -0,0 +1,10 @@
|
||||
# Guarded import: when the agent module cannot be imported, re-raise an
# ImportError that carries installation instructions and keeps the original
# error as its cause.
try:
    from ._magentic_one_coder_agent import MagenticOneCoderAgent
except ImportError as import_error:
    install_hint = (
        "Dependencies for MagenticOneCoderAgent not found. "
        'Please install autogen-ext with the "magentic-one" extra: '
        'pip install "agentdhal-ext[magentic-one]"'
    )
    raise ImportError(install_hint) from import_error

__all__ = ["MagenticOneCoderAgent"]
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import Any
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
)
|
||||
|
||||
# Fixed description used for the MagenticOne coder agent.
MAGENTIC_ONE_CODER_DESCRIPTION = (
    "A helpful and general-purpose AI assistant that has strong language skills, "
    "Python skills, and Linux command line skills."
)

# Fixed system prompt used for the MagenticOne coder agent.
MAGENTIC_ONE_CODER_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use the 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible."""
|
||||
|
||||
|
||||
class MagenticOneCoderAgent(AssistantAgent):
    """An agent, used by MagenticOne that provides coding assistance using an LLM model client.

    The prompts and description are sealed, to replicate the original MagenticOne configuration. See AssistantAgent if you wish to modify these values.
    """

    component_provider_override = "agentdhal_extensions.agents.magentic_one.MagenticOneCoderAgent"

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        **kwargs: Any,
    ):
        """Create the coder agent with the fixed MagenticOne description and system message.

        Args:
            name: The agent's name.
            model_client: The model client handed to AssistantAgent.
            **kwargs: Accepted but not forwarded to AssistantAgent.
        """
        # Only name and model_client come from the caller; the description and
        # system message are always the module-level MagenticOne constants.
        super().__init__(
            name,
            model_client,
            description=MAGENTIC_ONE_CODER_DESCRIPTION,
            system_message=MAGENTIC_ONE_CODER_SYSTEM_MESSAGE,
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._openai_agent import OpenAIAgent
|
||||
from ._openai_assistant_agent import OpenAIAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
"OpenAIAgent",
|
||||
"OpenAIAssistantAgent",
|
||||
]
|
||||
682
agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py
Normal file
682
agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py
Normal file
@@ -0,0 +1,682 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from agentdhal_agentchat import EVENT_LOGGER_NAME
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.models import UserMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore
|
||||
|
||||
# How many characters of image content to show when previewing it in logs and
# the UI. Base64-encoded images can be very long, so previews are truncated
# for readability.
IMAGE_CONTENT_PREVIEW_LENGTH = 50

# NOTE: We use the new Responses API, so ChatCompletion imports are not needed.

# Logger obtained under the agentchat event logger name.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
# TypedDict classes for built-in tool configurations
|
||||
class FileSearchToolConfig(TypedDict):
|
||||
"""Configuration for file_search tool."""
|
||||
|
||||
type: Literal["file_search"]
|
||||
vector_store_ids: List[str] # required - The IDs of the vector stores to search
|
||||
max_num_results: NotRequired[int] # optional
|
||||
ranking_options: NotRequired[Dict[str, Any]] # optional
|
||||
filters: NotRequired[Dict[str, Any]] # optional
|
||||
|
||||
|
||||
class WebSearchToolConfig(TypedDict):
|
||||
"""Configuration for web_search_preview tool."""
|
||||
|
||||
type: Literal["web_search_preview"]
|
||||
search_context_size: NotRequired[str] # optional
|
||||
user_location: NotRequired[Union[str, Dict[str, Any]]] # optional - Can be string or structured location
|
||||
|
||||
|
||||
class ComputerUseToolConfig(TypedDict):
    """Settings dictionary for the built-in ``computer_use_preview`` tool.

    Every key is required.
    """

    type: Literal["computer_use_preview"]

    # Display dimensions in pixels.
    display_height: int
    display_width: int

    # Environment type for computer use.
    environment: str
|
||||
|
||||
|
||||
class MCPToolConfig(TypedDict):
|
||||
"""Configuration for mcp tool."""
|
||||
|
||||
type: Literal["mcp"]
|
||||
server_label: str # required - Label for the MCP server
|
||||
server_url: str # required - URL of the MCP server
|
||||
allowed_tools: NotRequired[List[str]] # optional - List of allowed tools
|
||||
headers: NotRequired[Dict[str, str]] # optional - HTTP headers for requests
|
||||
require_approval: NotRequired[bool] # optional - Whether to require user approval
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(TypedDict):
    """Dict shape accepted for the built-in ``code_interpreter`` tool."""

    type: Literal["code_interpreter"]
    # Required: container configuration for code execution (a string or a dict).
    container: Union[str, Dict[str, Any]]
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(TypedDict):
|
||||
"""Configuration for image_generation tool."""
|
||||
|
||||
type: Literal["image_generation"]
|
||||
background: NotRequired[str] # optional - Background color or image
|
||||
input_image_mask: NotRequired[str] # optional - Mask for input image editing
|
||||
|
||||
|
||||
class LocalShellToolConfig(TypedDict):
    """Dict shape accepted for the built-in ``local_shell`` tool.

    WARNING: This tool is only supported with the 'codex-mini-latest' model
    and is available exclusively through the Responses API.
    """

    # ``type`` is the only key: local_shell currently has no additional parameters in the API.
    type: Literal["local_shell"]
|
||||
|
||||
|
||||
# Union type for all built-in tool configurations
|
||||
# Any one of the dict-style built-in tool configurations declared above.
BuiltinToolConfig = Union[
    FileSearchToolConfig,  # "file_search"
    WebSearchToolConfig,  # "web_search_preview"
    ComputerUseToolConfig,  # "computer_use_preview"
    MCPToolConfig,  # "mcp"
    CodeInterpreterToolConfig,  # "code_interpreter"
    ImageGenerationToolConfig,  # "image_generation"
    LocalShellToolConfig,  # "local_shell"
]
|
||||
|
||||
|
||||
# Define ImageMessage class early since it's used in _convert_message_to_openai_message
|
||||
class ImageMessage(BaseChatMessage):
    """A message containing an image."""

    # The image itself, kept as a string: a URL or base64 text.
    content: str

    def to_model_message(self) -> UserMessage:
        """Wrap the raw content string and the source in a ``UserMessage``."""
        return UserMessage(source=self.source, content=self.content)

    def to_model_text(self) -> str:
        """Return the fixed placeholder used in place of the image data."""
        return "[image]"

    def to_text(self) -> str:
        """Return a readable preview of the image content.

        Content longer than ``IMAGE_CONTENT_PREVIEW_LENGTH`` characters is cut to
        that length and followed by ``...``; shorter content is shown in full.
        """
        shown = self.content
        if len(shown) > IMAGE_CONTENT_PREVIEW_LENGTH:
            shown = f"{shown[:IMAGE_CONTENT_PREVIEW_LENGTH]}..."
        return f"[Image: {shown}]"
|
||||
|
||||
|
||||
class OpenAIMessageContent(TypedDict):
    """Text part of a multi-part message content list."""

    # Part discriminator; the converter in this module writes "text".
    type: str
    # The text itself.
    text: str
|
||||
|
||||
|
||||
class OpenAIImageUrlContent(TypedDict):
    """Inner ``image_url`` object of an image content part."""

    # Image reference as a string.
    url: str
|
||||
|
||||
|
||||
class OpenAIImageContent(TypedDict):
    """Image part of a multi-part message content list."""

    # Part discriminator; the converter in this module writes "image_url".
    type: str
    # Wrapper object holding the image URL.
    image_url: OpenAIImageUrlContent
|
||||
|
||||
|
||||
class OpenAIMessage(TypedDict):
    """A single role/content message in the dict form sent to the API."""

    # Role string; this module produces "user", "system" or "assistant".
    role: str
    # Either plain text or a list of text / image parts.
    content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]]
|
||||
|
||||
|
||||
def _convert_message_to_openai_message(
    message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage],
) -> OpenAIMessage:
    """Convert an AutoGen message to an OpenAI message format.

    Args:
        message: The chat message to convert.

    Returns:
        A ``{"role": ..., "content": ...}`` dict. For a ``TextMessage`` the role
        is the message source when that is "user", "system" or "assistant", and
        "user" otherwise. A ``MultiModalMessage`` becomes a "user" message with
        a list of content parts. Every other message type becomes a "user"
        message carrying ``str(message.content)``.
    """
    if isinstance(message, TextMessage):
        # Any source other than the three known roles (e.g. another agent's name)
        # is sent as user input.
        role = message.source if message.source in ("user", "system", "assistant") else "user"
        return {"role": role, "content": str(message.content)}

    if isinstance(message, MultiModalMessage):
        content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = []
        for part in message.content:
            if isinstance(part, str):
                # Plain-string parts used to match no branch and were silently
                # dropped. Presumably this is how text is carried in multimodal
                # content (as in upstream AutoGen) -- TODO confirm.
                content_parts.append({"type": "text", "text": part})
            elif isinstance(part, TextMessage):
                content_parts.append({"type": "text", "text": str(part.content)})
            elif isinstance(part, ImageMessage):
                content_parts.append({"type": "image_url", "image_url": {"url": str(part.content)}})
            else:
                # NOTE(review): other part types (e.g. image objects) are still not
                # converted; make the omission visible instead of hiding it.
                event_logger.warning(f"Skipping unsupported multimodal content part of type {type(part).__name__}")
        return {"role": "user", "content": content_parts}

    return {"role": "user", "content": str(message.content)}
|
||||
|
||||
|
||||
# Serializable snapshot of an OpenAIAgent, produced by save_state and consumed by load_state.
class OpenAIAgentState(BaseModel):
    # Discriminator naming the kind of state stored.
    type: str = Field(default="OpenAIAgentState")
    # ID of the most recent API response, if any.
    response_id: Optional[str] = None
    # Accumulated role/content message dicts.
    history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIAgentConfig(BaseModel):
    """
    Configuration model for OpenAI agent supporting OpenAI built-in tools only.

    .. versionchanged:: v0.7.0
       Added support for built-in tools in JSON configuration via _to_config and _from_config methods.
       The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format).
       Custom tools are not supported.
    """

    # Identity of the agent.
    name: str
    description: str

    # Model selection and system instructions.
    model: str
    instructions: str

    # Built-in tools, each given as a config dict or a tool-name string.
    tools: Optional[List[Union[Dict[str, Any], str]]] = None

    # Generation settings forwarded to the API request.
    temperature: Optional[float] = 1
    max_output_tokens: Optional[int] = None
    json_mode: bool = False
    store: bool = True
    truncation: str = "disabled"
|
||||
|
||||
|
||||
class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]):
|
||||
"""
|
||||
An agent implementation that uses the OpenAI Responses API to generate responses.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[openai]"
|
||||
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
|
||||
|
||||
This agent leverages the Responses API to generate responses with capabilities like:
|
||||
|
||||
* Multi-turn conversations
|
||||
* Built-in tool support (file_search, code_interpreter, web_search_preview, etc.)
|
||||
|
||||
Currently, custom tools are not supported.
|
||||
|
||||
.. versionchanged:: v0.7.0
|
||||
|
||||
Added support for built-in tool types like file_search, web_search_preview,
|
||||
code_interpreter, computer_use_preview, image_generation, and mcp.
|
||||
Added support for tool configurations with required and optional parameters.
|
||||
|
||||
Built-in tools are split into two categories:
|
||||
|
||||
**Tools that can use string format** (no required parameters):
|
||||
|
||||
- web_search_preview: Can be used as "web_search_preview" or with optional config
|
||||
(user_location, search_context_size)
|
||||
- image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask)
|
||||
- local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model)
|
||||
|
||||
**Tools that REQUIRE dict configuration** (have required parameters):
|
||||
|
||||
- file_search: MUST use dict with vector_store_ids (List[str])
|
||||
- computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str)
|
||||
- code_interpreter: MUST use dict with container (str or dict)
|
||||
- mcp: MUST use dict with server_label (str), server_url (str)
|
||||
|
||||
Using required-parameter tools in string format will raise a ValueError with helpful error messages.
|
||||
The tools parameter type annotation only accepts string values for tools that don't require parameters.
|
||||
|
||||
Note:
|
||||
Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent.
|
||||
Only OpenAI built-in tools provided via the Responses API are supported.
|
||||
|
||||
|
||||
Args:
|
||||
name (str): Name of the agent
|
||||
description (str): Description of the agent's purpose
|
||||
client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance
|
||||
model (str): Model to use (e.g. "gpt-4.1")
|
||||
instructions (str): System instructions for the agent
|
||||
tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use.
|
||||
Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell".
|
||||
Dict values can provide configuration for built-in tools with parameters.
|
||||
Required parameters for built-in tools:
|
||||
- file_search: vector_store_ids (List[str])
|
||||
- computer_use_preview: display_height (int), display_width (int), environment (str)
|
||||
- code_interpreter: container (str or dict)
|
||||
- mcp: server_label (str), server_url (str)
|
||||
Optional parameters for built-in tools:
|
||||
- file_search: max_num_results (int), ranking_options (dict), filters (dict)
|
||||
- web_search_preview: user_location (str or dict), search_context_size (str)
|
||||
- image_generation: background (str), input_image_mask (str)
|
||||
- mcp: allowed_tools (List[str]), headers (dict), require_approval (bool)
|
||||
Special tools with model restrictions:
|
||||
- local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support)
|
||||
Custom tools are not supported.
|
||||
temperature (Optional[float]): Temperature for response generation (default: 1)
|
||||
max_output_tokens (Optional[int]): Maximum output tokens
|
||||
json_mode (bool): Whether to use JSON mode (default: False)
|
||||
store (bool): Whether to store conversations (default: True)
|
||||
truncation (str): Truncation strategy (default: "disabled")
|
||||
|
||||
Example:
|
||||
|
||||
Basic usage with built-in tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.agents.openai import OpenAIAgent
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def example():
|
||||
client = AsyncOpenAI()
|
||||
agent = OpenAIAgent(
|
||||
name="SimpleAgent",
|
||||
description="A simple OpenAI agent using the Responses API",
|
||||
client=client,
|
||||
model="gpt-4.1",
|
||||
instructions="You are a helpful assistant.",
|
||||
tools=["web_search_preview"], # Only tools without required params
|
||||
)
|
||||
await Console(agent.run_stream(task="Search for recent AI developments"))
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Usage with configured built-in tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.agents.openai import OpenAIAgent
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def example_with_configs():
|
||||
client = AsyncOpenAI()
|
||||
# Configure tools with required and optional parameters
|
||||
tools = [
|
||||
# {
|
||||
# "type": "file_search",
|
||||
# "vector_store_ids": ["vs_abc123"], # required
|
||||
# "max_num_results": 10, # optional
|
||||
# },
|
||||
# {
|
||||
# "type": "computer_use_preview",
|
||||
# "display_height": 1024, # required
|
||||
# "display_width": 1280, # required
|
||||
# "environment": "linux", # required
|
||||
# },
|
||||
{
|
||||
"type": "code_interpreter",
|
||||
"container": {"type": "auto"}, # required
|
||||
},
|
||||
# {
|
||||
# "type": "mcp",
|
||||
# "server_label": "my-mcp-server", # required
|
||||
# "server_url": "http://localhost:3000", # required
|
||||
# },
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": { # optional - structured location
|
||||
"type": "approximate", # required: "approximate" or "exact"
|
||||
"country": "US", # optional
|
||||
"region": "CA", # optional
|
||||
"city": "San Francisco", # optional
|
||||
},
|
||||
"search_context_size": "low", # optional
|
||||
},
|
||||
# "image_generation", # Simple tools can still use string format
|
||||
]
|
||||
|
||||
agent = OpenAIAgent(
|
||||
name="ConfiguredAgent",
|
||||
description="An agent with configured tools",
|
||||
client=client,
|
||||
model="gpt-4.1",
|
||||
instructions="You are a helpful assistant with specialized tools.",
|
||||
tools=tools, # type: ignore
|
||||
)
|
||||
await Console(agent.run_stream(task="Search for recent AI developments"))
|
||||
|
||||
|
||||
asyncio.run(example_with_configs())
|
||||
|
||||
|
||||
Note:
|
||||
Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = OpenAIAgentConfig
|
||||
component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent"
|
||||
|
||||
def __init__(
    self: "OpenAIAgent",
    name: str,
    description: str,
    client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["web_search_preview", "image_generation", "local_shell"],
                BuiltinToolConfig,
            ]
        ]
    ] = None,
    temperature: Optional[float] = 1,
    max_output_tokens: Optional[int] = None,
    json_mode: bool = False,
    store: bool = True,
    truncation: str = "disabled",
) -> None:
    """Store the agent settings and register the requested built-in tools.

    String entries in ``tools`` are registered through ``_add_builtin_tool``.
    Dict entries that contain a ``"type"`` key are stored as given, with no
    further validation.

    Raises:
        ValueError: If a tool entry is neither a string nor a dict with a
            ``"type"`` key, or if ``_add_builtin_tool`` rejects a tool name.
    """
    super().__init__(name, description)

    # Request settings.
    self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client
    self._model: str = model
    self._instructions: str = instructions
    self._temperature: Optional[float] = temperature
    self._max_output_tokens: Optional[int] = max_output_tokens
    self._json_mode: bool = json_mode
    self._store: bool = store
    self._truncation: str = truncation

    # Conversation state.
    self._last_response_id: Optional[str] = None
    self._message_history: List[Dict[str, Any]] = []

    # Tool definitions in the dict form sent to the API.
    self._tools: List[Dict[str, Any]] = []
    if tools is None:
        return

    for entry in tools:
        if isinstance(entry, str):
            # Built-in tool referenced by name only.
            self._add_builtin_tool(entry)
            continue
        if not (isinstance(entry, dict) and "type" in entry):
            raise ValueError(f"Unsupported tool type: {type(entry)}")
        # Pre-configured built-in tool.
        self._tools.append(cast(dict[str, Any], entry))
|
||||
|
||||
def _add_builtin_tool(self, tool_name: str) -> None:
|
||||
"""Add a built-in tool by name."""
|
||||
# Skip if an identical tool has already been registered (idempotent behaviour)
|
||||
if any(td.get("type") == tool_name for td in self._tools):
|
||||
return # Duplicate – ignore rather than raise to stay backward-compatible
|
||||
# Only allow string format for tools that don't require parameters
|
||||
if tool_name == "web_search_preview":
|
||||
self._tools.append({"type": "web_search_preview"})
|
||||
elif tool_name == "image_generation":
|
||||
self._tools.append({"type": "image_generation"})
|
||||
elif tool_name == "local_shell":
|
||||
# Special handling for local_shell - very limited model support
|
||||
if self._model != "codex-mini-latest":
|
||||
raise ValueError(
|
||||
f"Tool 'local_shell' is only supported with model 'codex-mini-latest', "
|
||||
f"but current model is '{self._model}'. "
|
||||
f"This tool is available exclusively through the Responses API and has severe limitations. "
|
||||
f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with "
|
||||
f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead."
|
||||
)
|
||||
self._tools.append({"type": "local_shell"})
|
||||
elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]:
|
||||
# These tools require specific parameters and must use dict configuration
|
||||
raise ValueError(
|
||||
f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. "
|
||||
f"Use dict configuration instead. Required parameters for {tool_name}: "
|
||||
f"{self._get_required_params_help(tool_name)}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported built-in tool type: {tool_name}")
|
||||
|
||||
def _get_required_params_help(self, tool_name: str) -> str:
|
||||
"""Get help text for required parameters of a tool."""
|
||||
help_text = {
|
||||
"file_search": "vector_store_ids (List[str])",
|
||||
"code_interpreter": "container (str | dict)",
|
||||
"computer_use_preview": "display_height (int), display_width (int), environment (str)",
|
||||
"mcp": "server_label (str), server_url (str)",
|
||||
}
|
||||
return help_text.get(tool_name, "unknown parameters")
|
||||
|
||||
def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]:
    """Return a new plain ``dict`` with the same items as ``message`` (shallow copy)."""
    as_plain_dict: Dict[str, Any] = dict(message)
    return as_plain_dict
|
||||
|
||||
@property
def produced_message_types(
    self: "OpenAIAgent",
) -> Sequence[
    Union[
        Type[TextMessage],
        Type[MultiModalMessage],
        Type[StopMessage],
        Type[ToolCallSummaryMessage],
        Type[HandoffMessage],
    ]
]:
    """List the message classes this agent declares as its possible output."""
    declared_types = [
        TextMessage,
        MultiModalMessage,
        StopMessage,
        ToolCallSummaryMessage,
        HandoffMessage,
    ]
    return declared_types
|
||||
|
||||
# Custom tool execution is not supported by this agent.
|
||||
|
||||
def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
has_system_message = any(msg.get("role") == "system" for msg in messages)
|
||||
if self._instructions and not has_system_message:
|
||||
messages = [{"role": "system", "content": self._instructions}] + messages
|
||||
api_params: Dict[str, Any] = {
|
||||
"model": self._model,
|
||||
"input": messages, # Responses API expects 'input'
|
||||
}
|
||||
if self._temperature is not None:
|
||||
api_params["temperature"] = self._temperature
|
||||
if self._max_output_tokens is not None:
|
||||
api_params["max_output_tokens"] = self._max_output_tokens
|
||||
if self._tools:
|
||||
api_params["tools"] = self._tools
|
||||
if self._json_mode:
|
||||
api_params["text"] = {"type": "json_object"}
|
||||
api_params["store"] = self._store
|
||||
api_params["truncation"] = self._truncation
|
||||
if self._last_response_id:
|
||||
api_params["previous_response_id"] = self._last_response_id
|
||||
return api_params
|
||||
|
||||
async def on_messages(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> Response:
    """Run ``on_messages_stream`` to completion and return its final ``Response``.

    Items yielded before or after the ``Response`` are collected, and those not
    already present in ``response.inner_messages`` are appended to it.

    Raises:
        ValueError: If the stream ends without yielding a ``Response``.
    """
    final: Optional[Response] = None
    streamed: List[
        Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
    ] = []

    async for item in self.on_messages_stream(messages, cancellation_token):
        if isinstance(item, Response):
            final = item
            continue
        streamed.append(item)

    if final is None:
        raise ValueError("No response was generated")

    if final.inner_messages is None:
        final.inner_messages = []

    for event in streamed:
        if event in final.inner_messages:
            continue
        final.inner_messages = [*final.inner_messages, event]

    return final
|
||||
|
||||
async def on_messages_stream(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[
    Union[
        AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response
    ],
    None,
]:
    """Send the conversation to the Responses API and yield a single ``Response``.

    Each incoming message is converted to a role/content dict and appended to
    the stored history. The request input is the stored history followed by
    these new messages. On success the reply text is also appended to the
    history and the response ID is remembered. On any ``Exception`` the error
    is logged and a ``Response`` carrying the error text is yielded instead.
    """
    convertible = (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)

    # Start from a copy of the history so far, then add the new messages to both lists.
    request_input: List[Dict[str, Any]] = list(self._message_history)
    for incoming in messages:
        entry: Dict[str, Any]
        if isinstance(incoming, convertible):
            entry = self._convert_message_to_dict(_convert_message_to_openai_message(incoming))
        elif hasattr(incoming, "content"):
            entry = {"role": "user", "content": str(cast(Any, incoming).content)}
        else:
            entry = {"role": "user", "content": str(incoming)}
        request_input.append(entry)
        self._message_history.append(entry)

    inner_messages: List[AgentEvent | ChatMessage] = []
    api_params = self._build_api_parameters(request_input)

    try:
        client = cast(Any, self._client)
        pending_call = asyncio.ensure_future(client.responses.create(**api_params))
        response_obj = await cancellation_token.link_future(pending_call)

        self._last_response_id = getattr(response_obj, "id", None)

        # Use a readable placeholder when the API returns no content to aid debugging
        reply_text = getattr(response_obj, "output_text", None)
        content_str: str = "[no content returned]" if reply_text is None else str(reply_text)

        self._message_history.append({"role": "assistant", "content": content_str})
        yield Response(
            chat_message=TextMessage(source=self.name, content=content_str),
            inner_messages=inner_messages,
        )
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        event_logger.error(f"API error: {error_message}", exc_info=True)
        yield Response(
            chat_message=TextMessage(source=self.name, content=error_message),
            inner_messages=inner_messages,
        )
|
||||
|
||||
async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None:
    """Clear the stored message history and the remembered response ID.

    ``cancellation_token`` is accepted but not used.
    """
    self._message_history = []
    self._last_response_id = None
|
||||
|
||||
async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]:
    """Dump the last response ID and the message history via ``OpenAIAgentState``."""
    snapshot = OpenAIAgentState(
        history=self._message_history,
        response_id=self._last_response_id,
    )
    return snapshot.model_dump()
|
||||
|
||||
async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None:
    """Validate ``state`` as an ``OpenAIAgentState`` and restore ID and history from it."""
    restored = OpenAIAgentState.model_validate(state)
    self._message_history = restored.history
    self._last_response_id = restored.response_id
|
||||
|
||||
def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig:
    """Build the declarative config that describes this agent.

    The client is not part of the config. The tool list is copied shallowly:
    the config gets a new list holding the same tool dicts.

    Returns:
        OpenAIAgentConfig: The configuration that can recreate this agent.
    """
    fields: Dict[str, Any] = {
        "name": self.name,
        "description": self.description,
        "model": self._model,
        "instructions": self._instructions,
        "tools": list(self._tools),
        "temperature": self._temperature,
        "max_output_tokens": self._max_output_tokens,
        "json_mode": self._json_mode,
        "store": self._store,
        "truncation": self._truncation,
    }
    return OpenAIAgentConfig(**fields)
|
||||
|
||||
@classmethod
def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Create an OpenAI agent from a declarative config.

    The config carries no client, so a new ``AsyncOpenAI()`` is constructed
    with no arguments. Tools are passed to the constructor as stored (strings
    or dicts).

    Args:
        config: The configuration to load the agent from.

    Returns:
        OpenAIAgent: The reconstructed agent.
    """
    from openai import AsyncOpenAI

    fresh_client = AsyncOpenAI()
    settings: Dict[str, Any] = {
        "name": config.name,
        "description": config.description,
        "model": config.model,
        "instructions": config.instructions,
        "tools": config.tools,
        "temperature": config.temperature,
        "max_output_tokens": config.max_output_tokens,
        "json_mode": config.json_mode,
        "store": config.store,
        "truncation": config.truncation,
    }
    return cls(client=fresh_client, **settings)
|
||||
|
||||
# Add public API wrappers for configuration and tools
|
||||
def to_config(self) -> OpenAIAgentConfig:
    """Return this agent's config; delegates to ``_to_config``."""
    config = self._to_config()
    return config
|
||||
|
||||
@classmethod
def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Build an agent from ``config``; delegates to ``_from_config``."""
    agent = cls._from_config(config)
    return agent
|
||||
|
||||
@property
def tools(self) -> list[Any]:
    """The agent's registered tool definitions.

    This is the internal list itself, not a copy.
    """
    registered = self._tools
    return registered
|
||||
|
||||
@property
def model(self) -> str:
    """The model name this agent was configured with."""
    configured_model = self._model
    return configured_model
|
||||
@@ -0,0 +1,715 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiofiles
|
||||
from agentdhal_agentchat import EVENT_LOGGER_NAME
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
)
|
||||
from agentdhal_core import CancellationToken, FunctionCall, Image
|
||||
from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult
|
||||
from agentdhal_core.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types import FileObject
|
||||
from openai.types.beta import thread_update_params
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
|
||||
from openai.types.beta.threads import Message, MessageDeleted, Run
|
||||
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
|
||||
from openai.types.beta.threads.image_url_param import ImageURLParam
|
||||
from openai.types.beta.threads.message_content_part_param import (
|
||||
MessageContentPartParam,
|
||||
)
|
||||
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
from openai.types.vector_store import VectorStore
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
    """Convert an autogen Tool to an OpenAI Assistant function tool parameter.

    Only ``type``, ``properties`` and (when present) ``required`` are copied
    from the tool schema's ``parameters``. A schema without ``parameters``
    yields an empty parameters dict, and a missing description becomes "".
    """
    schema = tool.schema

    parameters: Dict[str, object] = {}
    if "parameters" in schema:
        declared = schema["parameters"]
        parameters["type"] = declared["type"]
        parameters["properties"] = declared["properties"]
        if "required" in declared:
            parameters["required"] = declared["required"]

    return FunctionToolParam(
        type="function",
        function=FunctionDefinition(
            name=schema["name"],
            description=schema.get("description", ""),
            parameters=parameters,
        ),
    )
|
||||
|
||||
|
||||
# Serializable state model for the Assistant-API agent; every ID field defaults to unset/empty.
class OpenAIAssistantAgentState(BaseModel):
    # Discriminator naming the kind of state stored.
    type: str = Field(default="OpenAIAssistantAgentState")

    # Remote object IDs.
    assistant_id: Optional[str] = None
    thread_id: Optional[str] = None
    vector_store_id: Optional[str] = None

    # ID lists.
    initial_message_ids: List[str] = Field(default_factory=list)
    uploaded_file_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(BaseChatAgent):
|
||||
"""An agent implementation that uses the Assistant API to generate responses.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[openai]" # For OpenAI Assistant
|
||||
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
|
||||
|
||||
|
||||
This agent leverages the Assistant API to create AI assistants with capabilities like:
|
||||
|
||||
* Code interpretation and execution
|
||||
* File handling and search
|
||||
* Custom function calling
|
||||
* Multi-turn conversations
|
||||
|
||||
The agent maintains a thread of conversation and can use various tools including
|
||||
|
||||
* Code interpreter: For executing code and working with files
|
||||
* File search: For searching through uploaded documents
|
||||
* Custom functions: For extending capabilities with user-defined tools
|
||||
|
||||
Key Features:
|
||||
|
||||
* Supports multiple file formats including code, documents, images
|
||||
* Can handle up to 128 tools per assistant
|
||||
* Maintains conversation context in threads
|
||||
* Supports file uploads for code interpreter and search
|
||||
* Vector store integration for efficient file search
|
||||
* Automatic file parsing and embedding
|
||||
|
||||
You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
Use the assistant to analyze data in a CSV file:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from agentdhal_core import CancellationToken
|
||||
import asyncio
|
||||
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def example():
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create an OpenAI client
|
||||
client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url")
|
||||
|
||||
# Create an assistant with code interpreter
|
||||
assistant = OpenAIAssistantAgent(
|
||||
name="PythonHelper",
|
||||
description="Helps with Python programming",
|
||||
client=client,
|
||||
model="gpt-4",
|
||||
instructions="You are a helpful Python programming assistant.",
|
||||
tools=["code_interpreter"],
|
||||
)
|
||||
|
||||
# Upload files for the assistant to use
|
||||
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
|
||||
|
||||
# Get response from the assistant
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_uploaded_files(cancellation_token)
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Use Azure OpenAI Assistant with AAD authentication:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncAzureOpenAI
|
||||
import asyncio
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def example():
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create an Azure OpenAI client
|
||||
token_provider = get_bearer_token_provider(DefaultAzureCredential())
|
||||
client = AsyncAzureOpenAI(
|
||||
azure_deployment="YOUR_AZURE_DEPLOYMENT",
|
||||
api_version="YOUR_API_VERSION",
|
||||
azure_endpoint="YOUR_AZURE_ENDPOINT",
|
||||
azure_ad_token_provider=token_provider,
|
||||
)
|
||||
|
||||
# Create an assistant with code interpreter
|
||||
assistant = OpenAIAssistantAgent(
|
||||
name="PythonHelper",
|
||||
description="Helps with Python programming",
|
||||
client=client,
|
||||
model="gpt-4o",
|
||||
instructions="You are a helpful Python programming assistant.",
|
||||
tools=["code_interpreter"],
|
||||
)
|
||||
|
||||
# Get response from the assistant
|
||||
response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token)
|
||||
|
||||
print(response)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Args:
|
||||
name (str): Name of the assistant
|
||||
description (str): Description of the assistant's purpose
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance
|
||||
model (str): Model to use (e.g. "gpt-4")
|
||||
instructions (str): System instructions for the assistant
|
||||
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
|
||||
assistant_id (Optional[str]): ID of existing assistant to use
|
||||
thread_id (Optional[str]): ID of existing thread to use
|
||||
metadata (Optional[Dict[str, str]]): Additional metadata for the assistant.
|
||||
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
|
||||
temperature (Optional[float]): Temperature for response generation
|
||||
tool_resources (Optional[ToolResources]): Additional tool configuration
|
||||
top_p (Optional[float]): Top p sampling parameter
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    name: str,
    description: str,
    client: AsyncOpenAI | AsyncAzureOpenAI,
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["code_interpreter", "file_search"],
                Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
            ]
        ]
    ] = None,
    assistant_id: Optional[str] = None,
    thread_id: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    response_format: Optional["AssistantResponseFormatOptionParam"] = None,
    temperature: Optional[float] = None,
    tool_resources: Optional["ToolResources"] = None,
    top_p: Optional[float] = None,
) -> None:
    """Store the configuration; no remote call is made here.

    The assistant and the thread are created or retrieved lazily by
    ``_ensure_initialized``.

    Args:
        name: Name of the agent.
        description: Description of the agent's purpose.
        client: OpenAI or Azure OpenAI async client.
        model: Model used when a new assistant is created.
        instructions: System instructions used when a new assistant is created.
        tools: Built-in tool names (``"code_interpreter"``, ``"file_search"``),
            ``Tool`` instances, or plain callables. Callables are wrapped in a
            ``FunctionTool`` whose description is the callable's docstring.
        assistant_id: ID of an existing assistant to reuse.
        thread_id: ID of an existing thread to reuse.
        metadata: Metadata passed when a new assistant is created.
        response_format: Response format passed when a new assistant is created.
        temperature: Sampling temperature passed when a new assistant is created.
        tool_resources: Tool resources passed when a new assistant is created.
        top_p: Top-p value passed when a new assistant is created.

    Raises:
        ValueError: If ``client`` is a ``ChatCompletionClient``, if a string
            tool name is not one of the built-in names, or if a tool is of an
            unsupported type.
    """
    if isinstance(client, ChatCompletionClient):
        raise ValueError(
            "Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
        )

    super().__init__(name, description)
    if tools is None:
        tools = []

    # Store original tools and converted tools separately: the originals are
    # needed to execute function calls locally, the converted ones are what is
    # sent to the API.
    self._original_tools: List[Tool] = []
    converted_tools: List["AssistantToolParam"] = []
    for tool in tools:
        if isinstance(tool, str):
            if tool == "code_interpreter":
                converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
            elif tool == "file_search":
                converted_tools.append(FileSearchToolParam(type="file_search"))
            else:
                # A misspelled built-in name used to be dropped silently,
                # leaving the assistant without the tool the caller asked for.
                raise ValueError(f"Unsupported tool name: {tool!r}")
        elif isinstance(tool, Tool):
            self._original_tools.append(tool)
            converted_tools.append(_convert_tool_to_function_param(tool))
        elif callable(tool):
            # Use a dedicated local so the ``description`` parameter is not rebound.
            tool_doc = getattr(tool, "__doc__", None)
            tool_description = tool_doc if tool_doc is not None else ""
            function_tool = FunctionTool(tool, description=tool_description)
            self._original_tools.append(function_tool)
            converted_tools.append(_convert_tool_to_function_param(function_tool))
        else:
            raise ValueError(f"Unsupported tool type: {type(tool)}")

    self._client = client
    self._assistant: Optional["Assistant"] = None
    self._thread: Optional["Thread"] = None
    self._init_thread_id = thread_id
    self._model = model
    self._instructions = instructions
    self._api_tools = converted_tools
    self._assistant_id = assistant_id
    self._metadata = metadata
    self._response_format = response_format
    self._temperature = temperature
    self._tool_resources = tool_resources
    self._top_p = top_p
    self._vector_store_id: Optional[str] = None
    self._uploaded_file_ids: List[str] = []

    # Variables to track initial state
    self._initial_message_ids: Set[str] = set()
    self._initial_state_retrieved: bool = False
|
||||
|
||||
async def _ensure_initialized(self) -> None:
    """Lazily create or retrieve the assistant and the thread.

    An assistant is retrieved when an assistant ID is configured and created
    otherwise; the same rule applies to the thread. The first call also
    snapshots the IDs of the messages already present in the thread.
    """
    beta = self._client.beta

    if self._assistant is None and self._assistant_id:
        self._assistant = await beta.assistants.retrieve(assistant_id=self._assistant_id)  # type: ignore[reportDeprecated]
    elif self._assistant is None:
        # Unset optional settings are sent as NOT_GIVEN so the API default applies.
        response_format = self._response_format or NOT_GIVEN
        tool_resources = self._tool_resources or NOT_GIVEN
        self._assistant = await beta.assistants.create(  # type: ignore[reportDeprecated]
            model=self._model,
            description=self.description,
            instructions=self._instructions,
            tools=self._api_tools,
            metadata=self._metadata,
            response_format=response_format,  # type: ignore
            temperature=self._temperature,
            tool_resources=tool_resources,  # type: ignore
            top_p=self._top_p,
        )

    if self._thread is None and self._init_thread_id:
        self._thread = await beta.threads.retrieve(thread_id=self._init_thread_id)  # type: ignore[reportDeprecated]
    elif self._thread is None:
        self._thread = await beta.threads.create()  # type: ignore[reportDeprecated]

    # The initial-state snapshot is taken exactly once per instance.
    if self._initial_state_retrieved:
        return
    await self._retrieve_initial_state()
    self._initial_state_retrieved = True
|
||||
|
||||
async def _retrieve_initial_state(self) -> None:
    """Record the IDs of every message currently in the thread.

    The thread is listed in ascending order, 100 messages per page, and the
    collected IDs are stored in ``self._initial_message_ids``.
    """
    seen_ids: Set[str] = set()
    cursor: str | NotGiven = NOT_GIVEN
    more_pages = True
    while more_pages:
        page: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(  # type: ignore[reportDeprecated]
            self._thread_id, after=cursor, order="asc", limit=100
        )
        seen_ids.update(item.id for item in page.data)
        more_pages = page.has_next_page()
        if more_pages:
            # Continue after the last message of the page just read.
            cursor = page.data[-1].id
    self._initial_message_ids = seen_ids
|
||||
|
||||
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
    """Message types this agent can produce: only ``TextMessage``."""
    supported_types = (TextMessage,)
    return supported_types
|
||||
|
||||
@property
def threads(self) -> AsyncThreads:
    """Shortcut to the client's ``beta.threads`` resource."""
    beta_api = self._client.beta
    return beta_api.threads
|
||||
|
||||
@property
def runs(self) -> AsyncRuns:
    """Shortcut to the client's ``beta.threads.runs`` resource."""
    thread_api = self._client.beta.threads
    return thread_api.runs
|
||||
|
||||
@property
def messages(self) -> AsyncMessages:
    """Shortcut to the client's ``beta.threads.messages`` resource."""
    thread_api = self._client.beta.threads
    return thread_api.messages
|
||||
|
||||
@property
def _get_assistant_id(self) -> str:
    """ID of the initialized assistant.

    Raises:
        ValueError: If the assistant has not been created or retrieved yet.
    """
    assistant = self._assistant
    if assistant is not None:
        return assistant.id
    raise ValueError("Assistant not initialized")
|
||||
|
||||
@property
def _thread_id(self) -> str:
    """ID of the initialized thread.

    Raises:
        ValueError: If the thread has not been created or retrieved yet.
    """
    thread = self._thread
    if thread is not None:
        return thread.id
    raise ValueError("Thread not initialized")
|
||||
|
||||
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
    """Run the locally registered tool named by ``tool_call``.

    Args:
        tool_call: The function call to execute; ``arguments`` is a JSON string.
        cancellation_token: Token forwarded to the tool.

    Returns:
        The tool result rendered with the tool's ``return_value_as_string``.

    Raises:
        ValueError: If no tools are registered or none matches the call's name.
    """
    if not self._original_tools:
        raise ValueError("No tools are available.")

    # The first registered tool with a matching name wins.
    for candidate in self._original_tools:
        if candidate.name == tool_call.name:
            selected = candidate
            break
    else:
        raise ValueError(f"The tool '{tool_call.name}' is not available.")

    parsed_arguments = json.loads(tool_call.arguments)
    outcome = await selected.run_json(parsed_arguments, cancellation_token, call_id=tool_call.id)
    return selected.return_value_as_string(outcome)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
    """Handle incoming messages and return the final response.

    Consumes ``on_messages_stream`` and returns the first ``Response`` it
    yields; intermediate events are discarded.

    Raises:
        AssertionError: If the stream ends without yielding a ``Response``.
    """
    async for item in self.on_messages_stream(messages, cancellation_token):
        if not isinstance(item, Response):
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
    self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
    """Add the messages to the thread, run the assistant and stream the outcome.

    The run is polled every 0.5 seconds. Whenever it requires function calls,
    a ``ToolCallRequestEvent`` is yielded, the tools are executed locally, a
    ``ToolCallExecutionEvent`` is yielded and the outputs are submitted back
    to the run. Once the run completes, a ``Response`` is yielded that wraps
    the first text part of the newest thread message.

    Args:
        messages: Messages to append to the thread before the run starts.
        cancellation_token: Token linked to every remote call made here.

    Raises:
        ValueError: If the run ends in a non-successful terminal state
            (``failed``, ``cancelled``, ``expired``, ``incomplete``), or if
            the newest thread message is missing or has no text content.
    """
    await self._ensure_initialized()

    # Process all messages in sequence
    for message in messages:
        await self.handle_incoming_message(message, cancellation_token)

    # Inner messages for tool calls
    inner_messages: List[BaseAgentEvent | BaseChatMessage] = []

    # Create and start a run
    run: Run = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.runs.create(  # type: ignore[reportDeprecated]
                thread_id=self._thread_id,
                assistant_id=self._get_assistant_id,
            )
        )
    )

    # Wait for run completion by polling
    while True:
        run = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.runs.retrieve(  # type: ignore[reportDeprecated]
                    thread_id=self._thread_id,
                    run_id=run.id,
                )
            )
        )

        if run.status == "failed":
            raise ValueError(f"Run failed: {run.last_error}")

        # These states are terminal as well: without this check the loop would
        # keep polling a run that can never reach "completed".
        if run.status in ("cancelled", "expired", "incomplete"):
            raise ValueError(f"Run ended with status '{run.status}': {run.last_error}")

        # If the run requires action (function calls), execute tools and continue
        if run.status == "requires_action" and run.required_action is not None:
            tool_calls: List[FunctionCall] = [
                FunctionCall(
                    id=required_tool_call.id,
                    name=required_tool_call.function.name,
                    arguments=required_tool_call.function.arguments,
                )
                for required_tool_call in run.required_action.submit_tool_outputs.tool_calls
                if required_tool_call.type == "function"
            ]

            # Add tool call message to inner messages
            tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
            inner_messages.append(tool_call_msg)
            event_logger.debug(tool_call_msg)
            yield tool_call_msg

            # Execute tool calls and get results; a failing tool is reported
            # back to the assistant as an error output instead of aborting.
            tool_outputs: List[FunctionExecutionResult] = []
            for tool_call in tool_calls:
                try:
                    result = await self._execute_tool_call(tool_call, cancellation_token)
                    is_error = False
                except Exception as e:
                    result = f"Error: {e}"
                    is_error = True
                tool_outputs.append(
                    FunctionExecutionResult(
                        content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
                    )
                )

            # Add tool result message to inner messages
            tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
            inner_messages.append(tool_result_msg)
            event_logger.debug(tool_result_msg)
            yield tool_result_msg

            # Submit tool outputs back to the run
            run = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.runs.submit_tool_outputs(  # type: ignore[reportDeprecated]
                        thread_id=self._thread_id,
                        run_id=run.id,
                        tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
                    )
                )
            )
            # Poll again immediately; the run has just been resumed.
            continue

        if run.status == "completed":
            break

        await asyncio.sleep(0.5)

    # Get messages after run completion (newest first, one message only)
    assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)  # type: ignore[reportDeprecated]
        )
    )

    if not assistant_messages.data:
        raise ValueError("No messages received from assistant")

    # Get the last message's content
    last_message = assistant_messages.data[0]
    if not last_message.content:
        raise ValueError(f"No content in the last message: {last_message}")

    # Extract text content
    text_content = [content for content in last_message.content if content.type == "text"]
    if not text_content:
        raise ValueError(f"Expected text content in the last message: {last_message.content}")

    # Return the assistant's response as a Response with inner messages
    chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
    yield Response(chat_message=chat_message, inner_messages=inner_messages)
|
||||
|
||||
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
    """Append ``message`` to the thread as a user message.

    String content is sent unchanged. Otherwise each part is converted:
    strings become text blocks and ``Image`` objects become image-URL blocks
    built from the image's data URI.

    Raises:
        ValueError: If a content part is neither a string nor an ``Image``.
    """
    model_message = message.to_model_message()
    raw_content = model_message.content

    payload: str | List[MessageContentPartParam] | None = None
    if isinstance(raw_content, str):
        payload = raw_content
    else:
        parts: List[MessageContentPartParam] = []
        for part in raw_content:
            if isinstance(part, str):
                parts.append(TextContentBlockParam(text=part, type="text"))
                continue
            if isinstance(part, Image):
                image_ref = ImageURLParam(url=part.data_uri)
                parts.append(ImageURLContentBlockParam(image_url=image_ref, type="image_url"))
                continue
            raise ValueError(f"Unsupported content type: {type(part)} in {message}")
        payload = parts

    create_call = self._client.beta.threads.messages.create(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        content=payload,
        role="user",
    )
    await cancellation_token.link_future(asyncio.ensure_future(create_call))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """Reset the thread by deleting every message added since initialization.

    All messages of the thread are listed (ascending, 100 per page); those
    whose ID is not in the initial snapshot are then deleted one by one.
    """
    await self._ensure_initialized()

    # Collect the IDs of messages that were not part of the initial snapshot.
    added_ids: List[str] = []
    cursor: str | NotGiven = NOT_GIVEN
    while True:
        page: AsyncCursorPage[Message] = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.list(self._thread_id, after=cursor, order="asc", limit=100)  # type: ignore[reportDeprecated]
            )
        )
        added_ids.extend(item.id for item in page.data if item.id not in self._initial_message_ids)
        if not page.has_next_page():
            break
        cursor = page.data[-1].id

    # Delete the collected messages.
    for message_id in added_ids:
        outcome: MessageDeleted = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.delete(message_id=message_id, thread_id=self._thread_id)  # type: ignore[reportDeprecated]
            )
        )
        assert outcome.deleted is True
|
||||
|
||||
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
    """Upload local files with purpose ``"assistants"`` and return their IDs.

    Args:
        file_paths: A single path or an iterable of paths.
        cancellation_token: Token linked to the read and upload operations.

    Returns:
        The IDs of the uploaded files, in upload order. Each ID is also
        remembered in ``self._uploaded_file_ids``.
    """
    await self._ensure_initialized()

    paths = [file_paths] if isinstance(file_paths, str) else file_paths

    uploaded_ids: List[str] = []
    for path in paths:
        base_name = os.path.basename(path)
        async with aiofiles.open(path, mode="rb") as handle:
            payload = await cancellation_token.link_future(asyncio.ensure_future(handle.read()))

        upload_call = self._client.files.create(file=(base_name, payload), purpose="assistants")
        uploaded: FileObject = await cancellation_token.link_future(asyncio.ensure_future(upload_call))
        uploaded_ids.append(uploaded.id)
        self._uploaded_file_ids.append(uploaded.id)

    return uploaded_ids
|
||||
|
||||
async def on_upload_for_code_interpreter(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and attach them to the thread's code-interpreter resources.

    The new file IDs are appended to the file IDs the thread already lists
    for the code interpreter, and the thread is updated with the result.
    """
    await self._ensure_initialized()

    new_file_ids = await self._upload_files(file_paths, cancellation_token)

    # Read the thread's current tool resources so existing files are kept.
    current_thread = await cancellation_token.link_future(
        asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))  # type: ignore[reportDeprecated]
    )
    resources: ToolResources = current_thread.tool_resources or ToolResources()
    interpreter: ToolResourcesCodeInterpreter = resources.code_interpreter or ToolResourcesCodeInterpreter()
    all_file_ids: List[str] = interpreter.file_ids or []
    all_file_ids += new_file_ids
    resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=all_file_ids)

    update_call = self._client.beta.threads.update(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        tool_resources=cast(thread_update_params.ToolResources, resources.model_dump()),
    )
    await cancellation_token.link_future(asyncio.ensure_future(update_call))
|
||||
|
||||
async def on_upload_for_file_search(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and add them to the vector store used for file search.

    A vector store is created and attached to the assistant on first use;
    the uploaded files are then added to it as one batch.

    Raises:
        ValueError: If no ``file_search`` tool is configured for this agent.
    """
    await self._ensure_initialized()

    # Check if file_search is enabled in tools
    file_search_enabled = False
    for api_tool in self._api_tools:
        if api_tool.get("type") == "file_search":
            file_search_enabled = True
            break
    if not file_search_enabled:
        raise ValueError(
            "File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
        )

    # Create vector store if not already created
    if self._vector_store_id is None:
        create_store = self._client.vector_stores.create()
        store: VectorStore = await cancellation_token.link_future(asyncio.ensure_future(create_store))
        self._vector_store_id = store.id

        # Update assistant with vector store ID
        attach_store = self._client.beta.assistants.update(
            assistant_id=self._get_assistant_id,
            tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
        )
        await cancellation_token.link_future(asyncio.ensure_future(attach_store))

    new_file_ids = await self._upload_files(file_paths, cancellation_token)

    # Create file batch with the file IDs
    batch_call = self._client.vector_stores.file_batches.create_and_poll(
        vector_store_id=self._vector_store_id, file_ids=new_file_ids
    )
    await cancellation_token.link_future(asyncio.ensure_future(batch_call))
|
||||
|
||||
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
    """Delete every file uploaded through this agent instance.

    Deletion is best effort: a failure is logged and the remaining files are
    still processed. The list of tracked file IDs is cleared afterwards.
    """
    await self._ensure_initialized()
    for uploaded_id in self._uploaded_file_ids:
        try:
            pending = asyncio.ensure_future(self._client.files.delete(file_id=uploaded_id))
            await cancellation_token.link_future(pending)
        except Exception as err:
            event_logger.error(f"Failed to delete file {uploaded_id}: {str(err)}")
    self._uploaded_file_ids = []
|
||||
|
||||
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
    """Delete the assistant unless an assistant ID is configured on this agent.

    A failure to delete is logged rather than raised.
    """
    await self._ensure_initialized()

    # Nothing to do without an assistant, or when an assistant ID is configured.
    if self._assistant is None or self._assistant_id:
        return

    try:
        pending = asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))  # type: ignore[reportDeprecated]
        await cancellation_token.link_future(pending)
        self._assistant = None
    except Exception as err:
        event_logger.error(f"Failed to delete assistant: {str(err)}")
|
||||
|
||||
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
    """Delete the vector store tracked by this agent, if there is one.

    A failure to delete is logged rather than raised; the stored ID is only
    cleared after a successful deletion.
    """
    await self._ensure_initialized()

    if self._vector_store_id is None:
        return

    try:
        pending = asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=self._vector_store_id))
        await cancellation_token.link_future(pending)
        self._vector_store_id = None
    except Exception as err:
        event_logger.error(f"Failed to delete vector store: {str(err)}")
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
    """Return the agent's state as a plain mapping.

    The live assistant and thread IDs are used when available; otherwise the
    configured IDs are saved. The initial message IDs, the vector store ID
    and the uploaded file IDs are included as well.
    """
    if self._assistant:
        assistant_id = self._assistant.id
    else:
        assistant_id = self._assistant_id

    if self._thread:
        thread_id = self._thread.id
    else:
        thread_id = self._init_thread_id

    snapshot = OpenAIAssistantAgentState(
        assistant_id=assistant_id,
        thread_id=thread_id,
        initial_message_ids=list(self._initial_message_ids),
        vector_store_id=self._vector_store_id,
        uploaded_file_ids=self._uploaded_file_ids,
    )
    return snapshot.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore the agent's configuration from a mapping made by ``save_state``.

    The mapping is validated as an ``OpenAIAssistantAgentState``. Only the
    stored IDs are replaced; an already retrieved assistant or thread object
    is not touched here.
    """
    restored = OpenAIAssistantAgentState.model_validate(state)
    self._assistant_id = restored.assistant_id
    self._init_thread_id = restored.thread_id
    self._initial_message_ids = set(restored.initial_message_ids)
    self._vector_store_id = restored.vector_store_id
    self._uploaded_file_ids = restored.uploaded_file_ids
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._video_surfer import VideoSurfer
|
||||
|
||||
__all__ = ["VideoSurfer"]
|
||||
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .tools import (
|
||||
extract_audio,
|
||||
get_screenshot_at,
|
||||
get_video_length,
|
||||
save_screenshot,
|
||||
transcribe_audio_with_timestamps,
|
||||
transcribe_video_screenshot,
|
||||
)
|
||||
|
||||
|
||||
class VideoSurfer(AssistantAgent):
    """
    VideoSurfer is a specialized agent designed to answer questions about a local video file.

    Installation:

    .. code-block:: bash

        pip install "agentdhal-ext[video-surfer]"

    The agent relies on a set of tools to extract information from the video: its length,
    screenshots at specific timestamps, and audio transcriptions. It combines these elements
    to answer user queries in detail.

    Available tools:

    - :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot`

    Args:
        name (str): The name of the agent.
        model_client (ChatCompletionClient): The model client used for generating responses.
        tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
            A list of tools or functions the agent can use. If not provided (or empty), the default video tools are used.
        description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
        system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.

    Example usage:

    The following example demonstrates how to create a video surfing agent with
    a model client and generate a response to a simple query about a local video
    called video.mp4.

        .. code-block:: python


            import asyncio
            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"
                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
                    )

                # Define termination condition
                termination = TextMentionTermination("TERMINATE")

                # Define a team
                agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of complex does his use? Can you save this example to disk as well?")
                await Console(stream)

            asyncio.run(main())

    The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat.

        .. code-block:: python

            import asyncio

            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.teams import MagenticOneGroupChat
            from agentdhal_agentchat.agents import UserProxyAgent
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"

                model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")

                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=model_client
                    )

                user_proxy = UserProxyAgent(
                    name="User"
                )

                # Define a team
                agent_team = MagenticOneGroupChat([user_proxy, video_agent], model_client=model_client,)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.")
                await Console(stream)

            asyncio.run(main())
    """

    DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video."

    DEFAULT_SYSTEM_MESSAGE = """
    You are a helpful agent that is an expert at answering questions from a video.
    When asked to answer a question about a video, you should:
    1. Check if that video is available locally.
    2. Use the transcription to find which part of the video the question is referring to.
    3. Optionally use screenshots from those timestamps
    4. Provide a detailed answer to the question.
    Reply with TERMINATE when the task has been completed.
    """

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        *,
        tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
        description: Optional[str] = None,
        system_message: Optional[str] = None,
    ):
        # Fall back to the built-in video tool set when no tools are supplied.
        # The screenshot transcription tool is the bound method below, because
        # it needs access to this agent's model client.
        if tools:
            selected_tools = tools
        else:
            selected_tools = [
                get_video_length,
                get_screenshot_at,
                save_screenshot,
                self.vs_transribe_video_screenshot,
                extract_audio,
                transcribe_audio_with_timestamps,
            ]

        agent_description = description if description else self.DEFAULT_DESCRIPTION
        agent_system_message = system_message if system_message else self.DEFAULT_SYSTEM_MESSAGE

        super().__init__(
            name=name,
            model_client=model_client,
            tools=selected_tools,
            description=agent_description,
            system_message=agent_system_message,
        )

    async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str:
        """
        Transcribes the video screenshot at a specific timestamp.

        Delegates to ``transcribe_video_screenshot`` using this agent's model client.

        Args:
            video_path (str): Path to the video file.
            timestamp (float): Timestamp to take the screenshot.

        Returns:
            str: Transcription of the video screenshot.
        """
        transcription = await transcribe_video_screenshot(video_path, timestamp, self._model_client)
        return transcription
|
||||
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import base64
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import cv2
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import whisper
|
||||
from agentdhal_core import Image as AGImage
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
def extract_audio(video_path: str, audio_output_path: str) -> str:
    """
    Extracts the audio track of a video with ffmpeg and writes it in MP3 format.

    An existing file at the output path is overwritten and ffmpeg runs quietly.

    :param video_path: Path to the video file.
    :param audio_output_path: Path to save the extracted audio file.
    :return: Confirmation message with the path to the saved audio file.
    """
    source = ffmpeg.input(video_path)  # type: ignore
    target = source.output(audio_output_path, format="mp3")  # type: ignore
    target.run(quiet=True, overwrite_output=True)  # type: ignore
    return f"Audio extracted and saved to {audio_output_path}."
|
||||
|
||||
|
||||
def transcribe_audio_with_timestamps(audio_path: str) -> str:
    """
    Transcribes an audio file with the Whisper "base" model, one line per segment.

    Each line has the form ``[start - end] text`` with times printed to two decimals.

    :param audio_path: Path to the audio file.
    :return: Transcription with timestamps.
    """
    whisper_model = whisper.load_model("base")  # type: ignore
    output: Dict[str, Any] = whisper_model.transcribe(audio_path, task="transcribe", language="en", verbose=False)  # type: ignore

    pieces: List[str] = []
    for entry in output["segments"]:
        begin: float = entry["start"]
        finish: float = entry["end"]
        spoken: str = entry["text"]
        pieces.append(f"[{begin:.2f} - {finish:.2f}] {spoken}\n")

    return "".join(pieces)
|
||||
|
||||
|
||||
def get_video_length(video_path: str) -> str:
    """
    Returns the length of the video in seconds.

    The duration is computed as the reported frame count divided by the
    reported frames per second.

    :param video_path: Path to the video file.
    :return: Duration of the video in seconds.
    :raises IOError: If the video cannot be opened or reports no usable frame rate.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # A non-positive frame rate would otherwise surface as a bare
        # ZeroDivisionError (or a meaningless negative duration).
        if fps <= 0:
            raise IOError(f"Cannot determine the frame rate of video file {video_path}")
        duration = frame_count / fps
    finally:
        # Release the capture handle on every path, including errors.
        cap.release()

    return f"The video is {duration:.2f} seconds long."
|
||||
|
||||
|
||||
def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None:
    """
    Captures a screenshot at the specified timestamp and saves it to the output path.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param output_path: Path to save the screenshot. The file format is determined by the extension in the path.
    :raises IOError: If the video cannot be opened, no frame can be read at the
        requested timestamp, or the image cannot be written to ``output_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Translate the timestamp into a frame index and seek to it.
        frame_number = int(timestamp * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path; previously a failed read leaked it.
        cap.release()

    if not ret:
        raise IOError(f"Failed to capture frame at {timestamp:.2f}s")

    # imwrite reports failure through its return value; do not ignore it,
    # otherwise the caller believes a file exists when it does not.
    if not cv2.imwrite(output_path, frame):
        raise IOError(f"Failed to write screenshot to {output_path}")
|
||||
|
||||
|
||||
async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str:
    """
    Describes the content of a video frame captured at the specified timestamp.

    The frame is JPEG-encoded, embedded in a base64 data URI and sent, together
    with a short instruction, to the supplied model client.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param model_client: ChatCompletionClient instance used to describe the image.
    :return: Description of the screenshot content, or a failure message when
        no frame could be captured or encoded.
    """
    screenshots = get_screenshot_at(video_path, [timestamp])
    if not screenshots:
        return "Failed to capture screenshot."

    _, frame = screenshots[0]
    # Convert the frame to JPEG bytes and then to a base64 data URI.
    encoded_ok, buffer = cv2.imencode(".jpg", frame)
    if not encoded_ok:
        return "Failed to capture screenshot."
    frame_base64 = base64.b64encode(buffer.tobytes()).decode("utf-8")
    screenshot_uri = f"data:image/jpeg;base64,{frame_base64}"

    messages = [
        UserMessage(
            content=[
                # The timestamp is interpolated here; the prompt previously
                # carried a literal, never-filled "{}" placeholder.
                f"Following is a screenshot from the video at {timestamp:.2f} seconds. Describe what you see here.",
                AGImage.from_uri(screenshot_uri),
            ],
            source="tool",
        )
    ]

    result = await model_client.create(messages=messages)
    return str(result.content)
|
||||
|
||||
|
||||
def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]:
    """
    Captures screenshots at the specified timestamps and returns them as Python objects.

    :param video_path: Path to the video file.
    :param timestamps: List of timestamps in seconds.
    :return: List of tuples containing timestamp and the corresponding frame (image).
             Each frame is a NumPy array (height x width x channels).
    :raises IOError: If the video cannot be opened, reports a non-positive
        frame rate, or a frame cannot be read.
    :raises ValueError: If a timestamp lies outside ``[0, duration]``.
    """
    screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Without a frame rate neither the duration nor a frame index can be derived.
        if fps <= 0:
            raise IOError(f"Cannot determine the frame rate of video file {video_path}")
        duration = total_frames / fps
        last_frame_index = max(int(total_frames) - 1, 0)

        for timestamp in timestamps:
            if not 0 <= timestamp <= duration:
                raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]")

            # A timestamp equal to the duration maps one past the final frame;
            # clamp so the accepted upper bound yields the last frame instead of a read failure.
            frame_number = min(int(timestamp * fps), last_frame_index)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = cap.read()
            if not ret:
                raise IOError(f"Failed to capture frame at {timestamp:.2f}s")

            # Append the timestamp and frame to the list
            screenshots.append((timestamp, frame))
    finally:
        # Release the capture on success and on every error path.
        cap.release()

    return screenshots
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._multimodal_web_surfer import MultimodalWebSurfer
|
||||
from .playwright_controller import PlaywrightController
|
||||
|
||||
__all__ = ["MultimodalWebSurfer", "PlaywrightController"]
|
||||
11
agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py
Normal file
11
agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
class WebSurferEvent:
    """Structured record describing something the web surfer agent did, suitable for passing to a logger."""

    # Name of the component that produced the event.
    source: str
    # Human-readable description of what happened.
    message: str
    # URL associated with the event.
    url: str
    # Name of the action that was performed, if the event describes one.
    action: str | None = None
    # Arguments supplied to the action, if any.
    arguments: Dict[str, Any] | None = None
|
||||
@@ -0,0 +1,988 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import aiofiles
|
||||
import PIL.Image
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
|
||||
from agentdhal_agentchat.utils import content_to_str, remove_images
|
||||
from agentdhal_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
|
||||
from agentdhal_core import Image as AGImage
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from PIL import Image
|
||||
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._events import WebSurferEvent
|
||||
from ._prompts import (
|
||||
WEB_SURFER_QA_PROMPT,
|
||||
WEB_SURFER_QA_SYSTEM_MESSAGE,
|
||||
WEB_SURFER_TOOL_PROMPT_MM,
|
||||
WEB_SURFER_TOOL_PROMPT_TEXT,
|
||||
)
|
||||
from ._set_of_mark import add_set_of_mark
|
||||
from ._tool_definitions import (
|
||||
TOOL_CLICK,
|
||||
TOOL_HISTORY_BACK,
|
||||
TOOL_HOVER,
|
||||
TOOL_READ_PAGE_AND_ANSWER,
|
||||
TOOL_SCROLL_DOWN,
|
||||
TOOL_SCROLL_UP,
|
||||
TOOL_SLEEP,
|
||||
TOOL_SUMMARIZE_PAGE,
|
||||
TOOL_TYPE,
|
||||
TOOL_VISIT_URL,
|
||||
TOOL_WEB_SEARCH,
|
||||
)
|
||||
from ._types import InteractiveRegion, UserContent
|
||||
from .playwright_controller import PlaywrightController
|
||||
|
||||
DEFAULT_CONTEXT_SIZE = 128000
|
||||
|
||||
|
||||
class MultimodalWebSurferConfig(BaseModel):
    """Declarative configuration schema for :class:`MultimodalWebSurfer`."""

    # The name of the agent.
    name: str
    # Component description of the model client used by the agent.
    model_client: ComponentModel
    # Folder where downloads are saved; None means no downloads are saved.
    downloads_folder: str | None = None
    # Description of the agent.
    description: str | None = None
    # Directory where debug information (e.g. screenshots) is saved.
    debug_dir: str | None = None
    # Whether the browser should run headless.
    headless: bool = True
    # Page the browser opens first.
    start_page: str | None = "https://www.bing.com/"
    # Whether to animate actions.
    animate_actions: bool = False
    # Whether to save screenshots.
    to_save_screenshots: bool = False
    # Whether to use OCR.
    use_ocr: bool = False
    # Browser channel to launch, if any.
    browser_channel: str | None = None
    # Browser data directory; when set, a persistent browser context is used.
    browser_data_dir: str | None = None
    # Whether to resize the viewport.
    to_resize_viewport: bool = True
|
||||
|
||||
|
||||
class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
|
||||
"""
|
||||
MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[web-surfer]"
|
||||
|
||||
It launches a Chromium browser and uses Playwright to interact with it, which lets the agent perform a variety of actions. The browser is launched on the first call to the agent and is reused for subsequent calls.
|
||||
|
||||
It must be used with a multimodal model client that supports function/tool calling, ideally GPT-4o currently.
|
||||
|
||||
|
||||
When :meth:`on_messages` or :meth:`on_messages_stream` is called, the following occurs:
|
||||
1) If this is the first call, the browser is initialized and the page is loaded. This is done in :meth:`_lazy_init`. The browser is only closed when :meth:`close` is called.
|
||||
2) The method :meth:`_generate_reply` is called, which then creates the final response as below.
|
||||
3) The agent takes a screenshot of the page, extracts the interactive elements, and prepares a set-of-mark screenshot with bounding boxes around the interactive elements.
|
||||
4) The agent makes a call to the :attr:`model_client` with the SOM screenshot, history of messages, and the list of available tools.
|
||||
- If the model returns a string, the agent returns the string as the final response.
|
||||
- If the model returns a list of tool calls, the agent executes the tool calls with :meth:`_execute_tool` using :attr:`_playwright_controller`.
|
||||
- The agent returns a final response which includes a screenshot of the page, page metadata, description of the action taken and the inner text of the webpage.
|
||||
5) If at any point the agent encounters an error, it returns the error message as the final response.
|
||||
|
||||
|
||||
.. note::
|
||||
Please note that using the MultimodalWebSurfer involves interacting with a digital world designed for humans, which carries inherent risks.
|
||||
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences.
|
||||
Moreover, be cautious that MultimodalWebSurfer may be susceptible to prompt injection attacks from webpages.
|
||||
|
||||
.. note::
|
||||
|
||||
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client used by the agent. Must be multimodal and support function calling.
|
||||
downloads_folder (str, optional): The folder where downloads are saved. Defaults to None, no downloads are saved.
|
||||
description (str, optional): The description of the agent. Defaults to MultimodalWebSurfer.DEFAULT_DESCRIPTION.
|
||||
debug_dir (str, optional): The directory where debug information is saved. Defaults to None.
|
||||
headless (bool, optional): Whether the browser should be headless. Defaults to True.
|
||||
start_page (str, optional): The start page for the browser. Defaults to MultimodalWebSurfer.DEFAULT_START_PAGE.
|
||||
animate_actions (bool, optional): Whether to animate actions. Defaults to False.
|
||||
to_save_screenshots (bool, optional): Whether to save screenshots. Defaults to False.
|
||||
use_ocr (bool, optional): Whether to use OCR. Defaults to False.
|
||||
browser_channel (str, optional): The browser channel. Defaults to None.
|
||||
browser_data_dir (str, optional): The browser data directory. Defaults to None.
|
||||
to_resize_viewport (bool, optional): Whether to resize the viewport. Defaults to True.
|
||||
playwright (Playwright, optional): The playwright instance. Defaults to None.
|
||||
context (BrowserContext, optional): The browser context. Defaults to None.
|
||||
|
||||
|
||||
|
||||
|
||||
Example usage:
|
||||
|
||||
The following example demonstrates how to create a web surfing agent with
|
||||
a model client and run it for multiple turns.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Define an agent
|
||||
web_surfer_agent = MultimodalWebSurfer(
|
||||
name="MultimodalWebSurfer",
|
||||
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
|
||||
)
|
||||
|
||||
# Define a team
|
||||
agent_team = RoundRobinGroupChat([web_surfer_agent], max_turns=3)
|
||||
|
||||
# Run the team and stream messages to the console
|
||||
stream = agent_team.run_stream(task="Navigate to the AutoGen readme on GitHub.")
|
||||
await Console(stream)
|
||||
# Close the browser controlled by the agent
|
||||
await web_surfer_agent.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
component_config_schema = MultimodalWebSurferConfig
|
||||
component_provider_override = "agentdhal_extensions.agents.web_surfer.MultimodalWebSurfer"
|
||||
|
||||
DEFAULT_DESCRIPTION = """
|
||||
A helpful assistant with access to a web browser.
|
||||
Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.).
|
||||
It can also summarize the entire page, or answer questions based on the content of the page.
|
||||
It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded.
|
||||
"""
|
||||
DEFAULT_START_PAGE = "https://www.bing.com/"
|
||||
|
||||
# Viewport dimensions
|
||||
VIEWPORT_HEIGHT = 900
|
||||
VIEWPORT_WIDTH = 1440
|
||||
|
||||
# Size of the image we send to the MLM
|
||||
# Current values represent a 0.85 scaling to fit within the GPT-4v short-edge constraints (768px)
|
||||
MLM_HEIGHT = 765
|
||||
MLM_WIDTH = 1224
|
||||
|
||||
SCREENSHOT_TOKENS = 1105
|
||||
|
||||
def __init__(
    self,
    name: str,
    model_client: ChatCompletionClient,
    downloads_folder: str | None = None,
    description: str = DEFAULT_DESCRIPTION,
    debug_dir: str | None = None,
    headless: bool = True,
    start_page: str | None = DEFAULT_START_PAGE,
    animate_actions: bool = False,
    to_save_screenshots: bool = False,
    use_ocr: bool = False,
    browser_channel: str | None = None,
    browser_data_dir: str | None = None,
    to_resize_viewport: bool = True,
    playwright: Playwright | None = None,
    context: BrowserContext | None = None,
):
    """
    Set up the agent's configuration and helpers without starting a browser.

    Only settings are recorded here; the browser itself is started later by
    :meth:`_lazy_init`.

    Raises:
        ValueError: If screenshots are requested without a ``debug_dir``, or if
            the model client reports that it does not support function calling.
    """
    super().__init__(name, description)

    # Reject inconsistent configuration before storing anything.
    wants_screenshots_without_dir = to_save_screenshots and debug_dir is None
    if wants_screenshots_without_dir:
        raise ValueError(
            "Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist."
        )
    function_calling_support = model_client.model_info["function_calling"]
    if function_calling_support is False:
        raise ValueError(
            "The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
        )

    # Model and browser launch settings
    self._model_client = model_client
    self.headless = headless
    self.browser_channel = browser_channel
    self.browser_data_dir = browser_data_dir
    # A falsy start page falls back to the class default.
    self.start_page = start_page if start_page else self.DEFAULT_START_PAGE

    # Behavioural switches and output locations
    self.downloads_folder = downloads_folder
    self.debug_dir = debug_dir
    self.to_save_screenshots = to_save_screenshots
    self.use_ocr = use_ocr
    self.to_resize_viewport = to_resize_viewport
    self.animate_actions = animate_actions

    # Runtime state; a playwright instance and context may be supplied by the caller.
    self._playwright: Playwright | None = playwright
    self._context: BrowserContext | None = context
    self._page: Page | None = None
    self._last_download: Download | None = None
    self._prior_metadata_hash: str | None = None
    self.logger = logging.getLogger(f"{EVENT_LOGGER_NAME}.{self.name}.MultimodalWebSurfer")
    self._chat_history: List[LLMMessage] = []

    # Remember the most recent download reported by the page.
    def _remember_download(download: Download) -> None:
        self._last_download = download

    self._download_handler = _remember_download

    # The controller performs the actual browser interactions.
    self._playwright_controller = PlaywrightController(
        animate_actions=animate_actions,
        downloads_folder=downloads_folder,
        viewport_width=self.VIEWPORT_WIDTH,
        viewport_height=self.VIEWPORT_HEIGHT,
        _download_handler=self._download_handler,
        to_resize_viewport=to_resize_viewport,
    )

    # Tools that are always offered to the model.
    self.default_tools = [
        TOOL_VISIT_URL,
        TOOL_WEB_SEARCH,
        TOOL_HISTORY_BACK,
        TOOL_CLICK,
        TOOL_TYPE,
        TOOL_READ_PAGE_AND_ANSWER,
        TOOL_SUMMARIZE_PAGE,
        TOOL_SLEEP,
        TOOL_HOVER,
    ]

    # Becomes True once the browser has been initialized.
    self.did_lazy_init = False
|
||||
|
||||
async def _lazy_init(
    self,
) -> None:
    """
    Start the browser and open the start page; intended to run once, on first use.

    Reuses a caller-supplied playwright instance or browser context when one
    was given to the constructor, otherwise creates them. On Windows a warning
    is issued when the event loop policy is not the proactor policy.
    """
    if sys.platform == "win32":
        active_policy = asyncio.get_event_loop_policy()
        proactor_policy_cls = getattr(asyncio, "WindowsProactorEventLoopPolicy", None)
        if proactor_policy_cls is not None and not isinstance(active_policy, proactor_policy_cls):
            policy_warning = (
                "The current event loop policy is not WindowsProactorEventLoopPolicy. "
                "This may cause issues with subprocesses. "
                "Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
                "For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
                "See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop."
            )
            warnings.warn(policy_warning, stacklevel=2)

    # Start from a clean slate for download and metadata tracking.
    self._last_download = None
    self._prior_metadata_hash = None

    # Options handed to chromium when we have to launch it ourselves.
    launch_options: Dict[str, Any] = {"headless": self.headless}
    if self.browser_channel is not None:
        launch_options["channel"] = self.browser_channel

    if self._playwright is None:
        self._playwright = await async_playwright().start()

    # Only create a context when the caller did not provide one.
    if self._context is None:
        chromium = self._playwright.chromium
        if self.browser_data_dir is not None:
            # A data directory means the profile should persist between runs.
            self._context = await chromium.launch_persistent_context(self.browser_data_dir, **launch_options)
        else:
            browser = await chromium.launch(**launch_options)
            self._context = await browser.new_context(
                user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0"
            )

    # Open the page and wire it up.
    self._context.set_default_timeout(60000)  # One minute
    self._page = await self._context.new_page()
    assert self._page is not None
    self._page.on("download", self._download_handler)
    if self.to_resize_viewport:
        viewport_size = {"width": self.VIEWPORT_WIDTH, "height": self.VIEWPORT_HEIGHT}
        await self._page.set_viewport_size(viewport_size)

    # The helper script ships next to this module.
    module_dir = os.path.abspath(os.path.dirname(__file__))
    await self._page.add_init_script(path=os.path.join(module_dir, "page_script.js"))
    await self._page.goto(self.start_page)
    await self._page.wait_for_load_state()

    # Prepare the debug directory, which stores the screenshots generated along the way.
    await self._set_debug_dir(self.debug_dir)
    self.did_lazy_init = True
|
||||
|
||||
async def close(self) -> None:
    """
    Close the browser and the page.

    Should be called when the agent is no longer needed. The page, the browser
    context and the playwright instance are closed in that order and the
    corresponding attributes are cleared. The lazy-initialization flag is
    cleared as well, so that a later call re-creates the browser instead of
    failing on a missing page.
    """
    if self._page is not None:
        await self._page.close()
        self._page = None
    if self._context is not None:
        await self._context.close()
        self._context = None
    if self._playwright is not None:
        await self._playwright.stop()
        self._playwright = None

    # Nothing is initialized any more; without this reset the next reply would
    # skip _lazy_init and trip over `assert self._page is not None`.
    self.did_lazy_init = False
|
||||
|
||||
async def _set_debug_dir(self, debug_dir: str | None) -> None:
|
||||
assert self._page is not None
|
||||
if self.debug_dir is None:
|
||||
return
|
||||
|
||||
if not os.path.isdir(self.debug_dir):
|
||||
os.mkdir(self.debug_dir)
|
||||
|
||||
if self.to_save_screenshots:
|
||||
current_timestamp = "_" + int(time.time()).__str__()
|
||||
screenshot_png_name = "screenshot" + current_timestamp + ".png"
|
||||
|
||||
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Screenshot: " + screenshot_png_name,
|
||||
)
|
||||
)
|
||||
|
||||
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
    """
    Message types this agent can return as its final chat message.

    :meth:`on_messages_stream` answers with a ``MultiModalMessage`` when the
    reply carries non-string content, and with a ``TextMessage`` when the
    reply is a plain string or an error report, so both are listed.
    """
    return (MultiModalMessage, TextMessage)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """
    Forget the conversation and send the browser back to the start page.

    Does nothing when the browser has not been initialized yet. When
    screenshots are enabled, a screenshot of the page after the reset is
    stored in the debug directory.
    """
    if not self.did_lazy_init:
        return
    assert self._page is not None

    self._chat_history.clear()
    visit_outcome = await self._playwright_controller.visit_page(self._page, self.start_page)
    reset_prior_metadata, reset_last_download = visit_outcome

    # Drop stale per-page state when the controller asks for it.
    if reset_last_download:
        self._last_download = None
    if reset_prior_metadata:
        self._prior_metadata_hash = None

    if self.to_save_screenshots:
        screenshot_png_name = f"screenshot_{int(time.time())}.png"
        screenshot_path = os.path.join(self.debug_dir, screenshot_png_name)  # type: ignore
        await self._page.screenshot(path=screenshot_path)
        screenshot_event = WebSurferEvent(
            source=self.name,
            url=self._page.url,
            message="Screenshot: " + screenshot_png_name,
        )
        self.logger.info(screenshot_event)

    reset_event = WebSurferEvent(
        source=self.name,
        url=self._page.url,
        message="Resetting browser.",
    )
    self.logger.info(reset_event)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
    """
    Handle incoming messages and return the final response.

    Consumes :meth:`on_messages_stream` and returns the first ``Response`` it
    produces; anything else emitted by the stream is skipped.

    Raises:
        AssertionError: If the stream ends without producing a ``Response``.
    """
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if not isinstance(item, Response):
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
    self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
    """
    Handle incoming messages and yield the final ``Response``.

    The incoming messages are appended to the chat history, a reply is
    generated, and a single ``Response`` is yielded: a ``TextMessage`` for a
    string reply, a ``MultiModalMessage`` otherwise. If generating the reply
    raises an ordinary exception, the traceback is recorded in the chat history
    and returned as a ``TextMessage`` instead.

    Cancellation (``asyncio.CancelledError``), ``KeyboardInterrupt`` and
    generator finalization are deliberately NOT converted into an error
    response; they propagate to the caller.
    """
    for chat_message in messages:
        self._chat_history.append(chat_message.to_model_message())

    # Per-call scratch state filled in while the reply is generated.
    self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
    self.model_usage: List[RequestUsage] = []

    response: Response
    try:
        content = await self._generate_reply(cancellation_token=cancellation_token)
        self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
        final_usage = RequestUsage(
            prompt_tokens=sum(u.prompt_tokens for u in self.model_usage),
            completion_tokens=sum(u.completion_tokens for u in self.model_usage),
        )
        if isinstance(content, str):
            response = Response(
                chat_message=TextMessage(content=content, source=self.name, models_usage=final_usage),
                inner_messages=self.inner_messages,
            )
        else:
            response = Response(
                chat_message=MultiModalMessage(content=content, source=self.name, models_usage=final_usage),
                inner_messages=self.inner_messages,
            )
    except Exception:
        # Only ordinary errors become an error reply. Catching BaseException
        # here used to swallow cancellation and, because the yield sat inside
        # the try block, also the GeneratorExit raised when the generator is
        # finalized, which added a bogus error to the history and yielded again.
        error_text = f"Web surfing error:\n\n{traceback.format_exc()}"
        self._chat_history.append(AssistantMessage(content=error_text, source=self.name))
        response = Response(chat_message=TextMessage(content=error_text, source=self.name))

    # Yield outside the try block so exceptions thrown into the generator at
    # this suspension point are never mistaken for a web surfing error.
    yield response
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:
    """
    Produce one reply: describe the page to the model, then answer or act.

    The model is shown the current page state (as a marked-up screenshot when it
    supports vision, as text otherwise) together with the available tools. A
    string reply from the model is returned as-is; a list of tool calls is
    handed to :meth:`_execute_tool`.

    Raises:
        AssertionError: If the model's reply is neither a string nor a list.
    """
    # The browser is started the first time a reply is needed.
    if not self.did_lazy_init:
        await self._lazy_init()

    assert self._page is not None

    # Work on a copy of the conversation with old screenshots stripped.
    history: List[LLMMessage] = remove_images(self._chat_history)

    # The latest message is the request; everything before it is context.
    user_request = history.pop() if len(history) else UserMessage(content="Empty request.", source="user")

    # Model families outside this set get no conversation context at all.
    families_with_history = (
        ModelFamily.GPT_4O,
        ModelFamily.O1,
        ModelFamily.O3,
        ModelFamily.GPT_4,
        ModelFamily.GPT_35,
    )
    if self._model_client.model_info["family"] not in families_with_history:
        history = []

    # Collect the interactive elements and build the set-of-mark screenshot.
    rects = await self._playwright_controller.get_interactive_rects(self._page)
    viewport = await self._playwright_controller.get_visual_viewport(self._page)
    screenshot = await self._page.screenshot()
    som_screenshot, visible_rects, rects_above, rects_below = add_set_of_mark(screenshot, rects)

    if self.to_save_screenshots:
        screenshot_png_name = f"screenshot_som_{int(time.time())}.png"
        som_screenshot.save(os.path.join(self.debug_dir, screenshot_png_name))  # type: ignore
        self.logger.info(
            WebSurferEvent(
                source=self.name,
                url=self._page.url,
                message="Screenshot: " + screenshot_png_name,
            )
        )

    # Start from the default tools; scrolling is offered only where it is possible.
    tools = self.default_tools.copy()
    page_top = viewport["pageTop"]
    if page_top > 5:
        tools.append(TOOL_SCROLL_UP)
    if (page_top + viewport["height"] + 5) < viewport["scrollHeight"]:
        tools.append(TOOL_SCROLL_DOWN)

    # Tell the model which element, if any, currently has the input focus.
    focused = await self._playwright_controller.get_focused_rect_id(self._page)
    focused_hint = ""
    if focused:
        target_label = self._target_name(focused, rects)
        name = f"(and name '{target_label}') " if target_label else ""

        # Fall back to a generic role when the element's role is unknown.
        try:
            role = rects[focused]["role"]
        except KeyError:
            role = "control"

        focused_hint = f"\nThe {role} with ID {focused} {name}currently has the input focus.\n\n"

    # Targets inside the viewport
    visible_targets = "\n".join(self._format_target_list(visible_rects, rects)) + "\n\n"

    # Targets outside the viewport, capped at 30 entries
    other_targets: List[str] = [
        *self._format_target_list(rects_above, rects),
        *self._format_target_list(rects_below, rects),
    ]
    if not other_targets:
        other_targets_str = ""
    else:
        if len(other_targets) > 30:
            other_targets = other_targets[:30] + ["..."]
        other_targets_str = (
            "Additional valid interaction targets include (but are not limited to):\n"
            + "\n".join(other_targets)
            + "\n\n"
        )

    state_description = "Your " + await self._get_state_description()
    tool_names = "\n".join(tool["name"] for tool in tools)
    page_title = await self._page.title()

    # Both prompt templates take the same fields.
    prompt_fields = {
        "state_description": state_description,
        "visible_targets": visible_targets,
        "other_targets_str": other_targets_str,
        "focused_hint": focused_hint,
        "tool_names": tool_names,
        "title": page_title,
        "url": self._page.url,
    }

    if self._model_client.model_info["vision"]:
        text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(**prompt_fields).strip()

        # Scale the screenshot for the MLM, and close the original
        scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
        som_screenshot.close()
        if self.to_save_screenshots:
            scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png"))  # type: ignore

        prompt_message = UserMessage(
            content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)],
            source=self.name,
        )
    else:
        text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(**prompt_fields).strip()
        prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name)

    # The page description goes first, the actual request last.
    history.extend([prompt_message, user_request])

    # Make the request
    response = await self._model_client.create(
        history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
    )

    self.model_usage.append(response.usage)
    message = response.content
    self._last_download = None

    # A plain string is a direct answer.
    if isinstance(message, str):
        self.inner_messages.append(TextMessage(content=message, source=self.name))
        return message

    # A list means the model chose to take an action.
    if isinstance(message, list):
        return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)

    # Anything else is unexpected.
    raise AssertionError(f"Unknown response format '{message}'")
|
||||
|
||||
async def _execute_tool(
    self,
    message: List[FunctionCall],
    rects: Dict[str, InteractiveRegion],
    tool_names: str,
    cancellation_token: Optional[CancellationToken] = None,
) -> UserContent:
    """Run the tool call requested by the model and describe the resulting browser state.

    Only the first entry of ``message`` is executed; any further function calls
    in the list are ignored.

    Args:
        message: Function calls returned by the model. ``message[0].arguments``
            must be a JSON object string.
        rects: Interactive regions keyed by target id, used to look up a
            human-readable name for the element being acted on.
        tool_names: Text listing the available tools; only used in the error
            raised for an unknown tool name.
        cancellation_token: Forwarded to ``_summarize_page`` for the
            ``answer_question`` and ``summarize_page`` tools.

    Returns:
        For ``answer_question`` / ``summarize_page``: the string produced by
        ``_summarize_page``. For every other tool: a two-element list holding
        the textual observation and a screenshot of the page taken after the
        action.

    Raises:
        ValueError: If the requested tool name is not one of the handled tools.
    """
    # Execute the tool
    name = message[0].name
    args = json.loads(message[0].arguments)
    action_description = ""
    assert self._page is not None
    # Record the requested action both in the structured log and in the inner message trace.
    self.logger.info(
        WebSurferEvent(
            source=self.name,
            url=self._page.url,
            action=name,
            arguments=args,
            message=f"{name}( {json.dumps(args)} )",
        )
    )
    self.inner_messages.append(TextMessage(content=f"{name}( {json.dumps(args)} )", source=self.name))

    if name == "visit_url":
        # NOTE(review): `url` is None when the argument is missing, in which case
        # `url.startswith` below raises AttributeError.
        url = args.get("url")
        action_description = f"I typed '{url}' into the browser address bar."
        # Check if the argument starts with a known protocol
        if url.startswith(("https://", "http://", "file://", "about:")):
            reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
                self._page, url
            )
        # If the argument contains a space, treat it as a search query
        elif " " in url:
            reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
                self._page, f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH"
            )
        # Otherwise, prefix with https://
        else:
            reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
                self._page, "https://" + url
            )
        # visit_page reports whether the cached download / metadata hash should be discarded.
        if reset_last_download and self._last_download is not None:
            self._last_download = None
        if reset_prior_metadata and self._prior_metadata_hash is not None:
            self._prior_metadata_hash = None
    elif name == "history_back":
        action_description = "I clicked the browser back button."
        await self._playwright_controller.back(self._page)

    elif name == "web_search":
        query = args.get("query")
        action_description = f"I typed '{query}' into the browser search bar."
        reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
            self._page, f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH"
        )
        if reset_last_download and self._last_download is not None:
            self._last_download = None
        if reset_prior_metadata and self._prior_metadata_hash is not None:
            self._prior_metadata_hash = None
    elif name == "scroll_up":
        action_description = "I scrolled up one page in the browser."
        await self._playwright_controller.page_up(self._page)
    elif name == "scroll_down":
        action_description = "I scrolled down one page in the browser."
        await self._playwright_controller.page_down(self._page)

    elif name == "click":
        # Target ids are compared as strings because `rects` is keyed by str.
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I clicked '{target_name}'."
        else:
            action_description = "I clicked the control."
        new_page_tentative = await self._playwright_controller.click_id(self._page, target_id)
        # A non-None result replaces the active page (logged as a new tab or window),
        # so the cached metadata hash no longer applies.
        if new_page_tentative is not None:
            self._page = new_page_tentative
            self._prior_metadata_hash = None
            self.logger.info(
                WebSurferEvent(
                    source=self.name,
                    url=self._page.url,
                    message="New tab or window.",
                )
            )
    elif name == "input_text":
        input_field_id = str(args.get("input_field_id"))
        text_value = str(args.get("text_value"))
        input_field_name = self._target_name(input_field_id, rects)
        if input_field_name:
            action_description = f"I typed '{text_value}' into '{input_field_name}'."
        else:
            action_description = f"I input '{text_value}'."
        await self._playwright_controller.fill_id(self._page, input_field_id, text_value)

    elif name == "scroll_element_up":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)

        if target_name:
            action_description = f"I scrolled '{target_name}' up."
        else:
            action_description = "I scrolled the control up."

        await self._playwright_controller.scroll_id(self._page, target_id, "up")

    elif name == "scroll_element_down":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)

        if target_name:
            action_description = f"I scrolled '{target_name}' down."
        else:
            action_description = "I scrolled the control down."

        await self._playwright_controller.scroll_id(self._page, target_id, "down")

    elif name == "answer_question":
        question = str(args.get("question"))
        # NOTE: this description is never used because the branch returns immediately.
        action_description = f"I answered the following question '{question}' based on the web page."
        # Do Q&A on the DOM. No need to take further action. Browser state does not change.
        return await self._summarize_page(question=question, cancellation_token=cancellation_token)
    elif name == "summarize_page":
        # Summarize the DOM. No need to take further action. Browser state does not change.
        # NOTE: this description is never used because the branch returns immediately.
        action_description = "I summarized the current web page"
        return await self._summarize_page(cancellation_token=cancellation_token)

    elif name == "hover":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I hovered over '{target_name}'."
        else:
            action_description = "I hovered over the control."
        await self._playwright_controller.hover_id(self._page, target_id)

    elif name == "sleep":
        action_description = "I am waiting a short period of time before taking further action."
        await self._playwright_controller.sleep(self._page, 3)

    else:
        raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")

    # Let the page settle after the action before observing it
    # (the 3 passed to sleep is presumably seconds — confirm against the controller).
    await self._page.wait_for_load_state()
    await self._playwright_controller.sleep(self._page, 3)

    # Handle downloads
    # NOTE(review): `_last_download` is presumably set by a download handler registered
    # elsewhere in the class — it is not assigned a download anywhere in this method.
    if self._last_download is not None and self.downloads_folder is not None:
        fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
        await self._last_download.save_as(fname)  # type: ignore
        # Replace the page with a small confirmation document (as a base64 data URL)
        # so the observation below reports the download.
        page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
        await self._page.goto(
            "data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
        )
        await self._page.wait_for_load_state()

    # Handle metadata
    # The metadata is hashed so it is only repeated in the observation when it changed
    # since the previous observation.
    page_metadata = json.dumps(await self._playwright_controller.get_page_metadata(self._page), indent=4)
    metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest()
    if metadata_hash != self._prior_metadata_hash:
        page_metadata = (
            "\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
        )
    else:
        page_metadata = ""
    self._prior_metadata_hash = metadata_hash

    new_screenshot = await self._page.screenshot()
    if self.to_save_screenshots:
        # File names use a whole-second timestamp, so two screenshots taken within the
        # same second share a name.
        current_timestamp = "_" + int(time.time()).__str__()
        screenshot_png_name = "screenshot" + current_timestamp + ".png"

        async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file:  # type: ignore
            await file.write(new_screenshot)  # type: ignore
        self.logger.info(
            WebSurferEvent(
                source=self.name,
                url=self._page.url,
                message="Screenshot: " + screenshot_png_name,
            )
        )

    # Return the complete observation
    state_description = "The " + await self._get_state_description()
    message_content = (
        f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page."
    )

    return [
        re.sub(r"(\n\s*){3,}", "\n\n", message_content),  # Removing blank lines
        AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))),
    ]
|
||||
|
||||
async def _get_state_description(self) -> str:
|
||||
assert self._playwright_controller is not None
|
||||
assert self._page is not None
|
||||
|
||||
# Describe the viewport of the new page in words
|
||||
viewport = await self._playwright_controller.get_visual_viewport(self._page)
|
||||
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
|
||||
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
|
||||
if percent_scrolled < 1: # Allow some rounding error
|
||||
position_text = "at the top of the page"
|
||||
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
|
||||
position_text = "at the bottom of the page"
|
||||
else:
|
||||
position_text = str(percent_scrolled) + "% down from the top of the page"
|
||||
|
||||
visible_text = await self._playwright_controller.get_visible_text(self._page)
|
||||
|
||||
# Return the complete observation
|
||||
page_title = await self._page.title()
|
||||
message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n"
|
||||
message_content += f"The following text is visible in the viewport:\n\n{visible_text}"
|
||||
return message_content
|
||||
|
||||
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
|
||||
try:
|
||||
return rects[target]["aria_name"].strip()
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion]) -> List[str]:
|
||||
"""
|
||||
Format the list of targets in the webpage as a string to be used in the agent's prompt.
|
||||
"""
|
||||
targets: List[str] = []
|
||||
for r in list(set(ids)):
|
||||
if r in rects:
|
||||
# Get the role
|
||||
aria_role = rects[r].get("role", "").strip()
|
||||
if len(aria_role) == 0:
|
||||
aria_role = rects[r].get("tag_name", "").strip()
|
||||
|
||||
# Get the name
|
||||
aria_name = re.sub(r"[\n\r]+", " ", rects[r].get("aria_name", "")).strip()
|
||||
|
||||
# What are the actions?
|
||||
actions = ['"click", "hover"']
|
||||
if rects[r]["role"] in ["textbox", "searchbox", "search"]:
|
||||
actions = ['"input_text"']
|
||||
actions_str = "[" + ",".join(actions) + "]"
|
||||
|
||||
targets.append(f'{{"id": {r}, "name": "{aria_name}", "role": "{aria_role}", "tools": {actions_str} }}')
|
||||
|
||||
return targets
|
||||
|
||||
async def _summarize_page(
    self,
    question: str | None = None,
    cancellation_token: Optional[CancellationToken] = None,
) -> str:
    """Summarize the current page, optionally with respect to a question.

    The page's markdown is added to the prompt line by line until the model's
    context budget is used up. Vision-capable models also receive a scaled
    screenshot, for which ``SCREENSHOT_TOKENS`` are kept in reserve.

    Args:
        question: When given, the summary is requested with respect to it.
        cancellation_token: Forwarded to the model client's ``create`` call.

    Returns:
        The model's text response, or "Nothing to summarize." when no page text
        fits into (or exists for) the prompt.
    """
    assert self._page is not None

    page_markdown: str = await self._playwright_controller.get_page_markdown(self._page)

    title: str = self._page.url
    try:
        title = await self._page.title()
    except Exception:
        # Best effort: fall back to the URL when the title cannot be read.
        pass

    # Take a screenshot and scale it; the full-size image is released immediately.
    with Image.open(io.BytesIO(await self._page.screenshot())) as screenshot:
        scaled_screenshot = screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))

    # try/finally guarantees the scaled image is closed on every exit path,
    # including the early "Nothing to summarize." return and model errors.
    try:
        ag_image = AGImage.from_pil(scaled_screenshot)
        is_vision = self._model_client.model_info["vision"]

        # Prepare the system prompt
        messages: List[LLMMessage] = []
        messages.append(SystemMessage(content=WEB_SURFER_QA_SYSTEM_MESSAGE))
        prompt = WEB_SURFER_QA_PROMPT(title, question)

        # Grow the buffer (which is added to the prompt) until we overflow the context window or run out of lines
        buffer = ""
        # keepends=True preserves the line breaks; without it consecutive
        # lines were concatenated into one unreadable run of text.
        for line in page_markdown.splitlines(keepends=True):
            trial_message = UserMessage(
                content=prompt + buffer + line,
                source=self.name,
            )

            try:
                remaining = self._model_client.remaining_tokens(messages + [trial_message])
            except KeyError:
                # Use the default if the model isn't found
                remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message])

            # Stop as soon as the text alone no longer fits. This applies to
            # every model, not only vision-capable ones.
            if remaining <= 0:
                break

            # Vision models additionally need room for the screenshot.
            if is_vision and remaining <= self.SCREENSHOT_TOKENS:
                break

            buffer += line

        # Nothing to do
        buffer = buffer.strip()
        if len(buffer) == 0:
            return "Nothing to summarize."

        # Append the message
        if is_vision:
            # Multimodal
            messages.append(
                UserMessage(
                    content=[
                        prompt + buffer,
                        ag_image,
                    ],
                    source=self.name,
                )
            )
        else:
            # Text only
            messages.append(
                UserMessage(
                    content=prompt + buffer,
                    source=self.name,
                )
            )

        # Generate the response
        response = await self._model_client.create(messages, cancellation_token=cancellation_token)
        self.model_usage.append(response.usage)
        assert isinstance(response.content, str)
        return response.content
    finally:
        scaled_screenshot.close()
|
||||
|
||||
def _to_config(self) -> MultimodalWebSurferConfig:
    """Capture this agent's constructor settings in a serializable config object.

    The model client is stored in its dumped component form.
    """
    serialized_client = self._model_client.dump_component()
    config = MultimodalWebSurferConfig(
        name=self.name,
        model_client=serialized_client,
        downloads_folder=self.downloads_folder,
        description=self.description,
        debug_dir=self.debug_dir,
        headless=self.headless,
        start_page=self.start_page,
        animate_actions=self.animate_actions,
        to_save_screenshots=self.to_save_screenshots,
        use_ocr=self.use_ocr,
        browser_channel=self.browser_channel,
        browser_data_dir=self.browser_data_dir,
        to_resize_viewport=self.to_resize_viewport,
    )
    return config
|
||||
|
||||
@classmethod
def _from_config(cls, config: MultimodalWebSurferConfig) -> Self:
    """Build an agent from a config object.

    The model client is re-created from its component description. A falsy
    description or start page falls back to the class defaults.
    """
    client = ChatCompletionClient.load_component(config.model_client)
    description = config.description if config.description else cls.DEFAULT_DESCRIPTION
    start_page = config.start_page if config.start_page else cls.DEFAULT_START_PAGE
    return cls(
        name=config.name,
        model_client=client,
        downloads_folder=config.downloads_folder,
        description=description,
        debug_dir=config.debug_dir,
        headless=config.headless,
        start_page=start_page,
        animate_actions=config.animate_actions,
        to_save_screenshots=config.to_save_screenshots,
        use_ocr=config.use_ocr,
        browser_channel=config.browser_channel,
        browser_data_dir=config.browser_data_dir,
        to_resize_viewport=config.to_resize_viewport,
    )
|
||||
@@ -0,0 +1,52 @@
|
||||
# Tool-selection prompt for the case where a screenshot with outlined, numbered
# interactive elements accompanies the text. It is a str.format template with the
# placeholders state_description, visible_targets, other_targets_str,
# focused_hint, tool_names, title and url.
WEB_SURFER_TOOL_PROMPT_MM = """
{state_description}

Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:

{visible_targets}{other_targets_str}{focused_hint}

You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:

{tool_names}

When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)

My request follows:
"""

# Text-only variant of the tool-selection prompt: it lists the interactive
# components without referring to a screenshot. Same placeholders as the
# multimodal template above.
WEB_SURFER_TOOL_PROMPT_TEXT = """
{state_description}

You have also identified the following interactive components:

{visible_targets}{other_targets_str}{focused_hint}

You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:

{tool_names}

When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)

My request follows:
"""


# System message used for page summarization and page question answering.
WEB_SURFER_QA_SYSTEM_MESSAGE = """
You are a helpful assistant that can summarize long documents to answer question.
"""
|
||||
|
||||
|
||||
def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str:
    """Build the instruction that precedes the page text in a summarization request.

    Args:
        title: Title of the page being summarized.
        question: When given, the summary is requested with respect to it.

    Returns:
        The prompt text, ending in a blank line so page content can be appended.
    """
    intro = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport."
    focus = "" if question is None else f" with respect to '{question}'"
    return f"{intro} Please summarize the webpage into one or two paragraphs{focus}:\n\n"
|
||||
@@ -0,0 +1,96 @@
|
||||
import io
|
||||
import random
|
||||
from typing import BinaryIO, Dict, List, Tuple, cast
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from ._types import DOMRectangle, InteractiveRegion
|
||||
|
||||
TOP_NO_LABEL_ZONE = 20  # Don't print any labels close to the top of the page
|
||||
|
||||
|
||||
def add_set_of_mark(
    screenshot: bytes | Image.Image | io.BufferedIOBase, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
    """Annotate a screenshot with labelled outlines for the interactive regions.

    Args:
        screenshot: The page screenshot as a PIL image, raw image bytes, or a
            binary file-like object.
        ROIs: Interactive regions keyed by their id.

    Returns:
        A tuple ``(annotated_image, visible_ids, ids_above, ids_below)`` as
        produced by ``_add_set_of_mark``.
    """
    if isinstance(screenshot, Image.Image):
        return _add_set_of_mark(screenshot, ROIs)

    if isinstance(screenshot, bytes):
        screenshot = io.BytesIO(screenshot)

    # TODO: Not sure why this cast was needed, but by this point screenshot is a binary file-like object
    # The context manager closes the decoded image even if annotation raises.
    with Image.open(cast(BinaryIO, screenshot)) as image:
        return _add_set_of_mark(image, ROIs)
|
||||
|
||||
|
||||
def _add_set_of_mark(
    screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
    """Draw labelled outlines on a grayscale copy of the screenshot.

    Each non-empty rectangle whose horizontal midpoint lies inside the image is
    classified by its vertical midpoint: above the image, below it, or visible.
    Only visible rectangles are drawn. An id is appended once per qualifying
    rectangle, so a region with several rectangles may appear more than once.

    Returns:
        ``(composited_image, visible_ids, ids_above, ids_below)``.
    """
    visible_rects: List[str] = []
    rects_above: List[str] = []  # Scroll up to see
    rects_below: List[str] = []  # Scroll down to see

    label_font = ImageFont.load_default(14)
    base = screenshot.convert("L").convert("RGBA")
    page_width, page_height = base.size
    overlay = Image.new("RGBA", base.size)
    canvas = ImageDraw.Draw(overlay)

    for region_id, region in ROIs.items():
        for rect in region["rects"]:
            # Skip missing and zero-area rectangles
            if not rect or rect["width"] * rect["height"] == 0:
                continue

            center_x = (rect["right"] + rect["left"]) / 2.0
            center_y = (rect["top"] + rect["bottom"]) / 2.0

            # Rectangles centered left or right of the image are ignored entirely
            if not (0 <= center_x < page_width):
                continue

            if center_y < 0:
                rects_above.append(region_id)
            elif center_y >= page_height:
                rects_below.append(region_id)
            else:
                visible_rects.append(region_id)
                _draw_roi(canvas, int(region_id), label_font, rect)

    comp = Image.alpha_composite(base, overlay)
    overlay.close()
    return comp, visible_rects, rects_above, rects_below
|
||||
|
||||
|
||||
def _draw_roi(
    draw: ImageDraw.ImageDraw, idx: int, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, rect: DOMRectangle
) -> None:
    """Outline one rectangle and stamp its numeric label next to it.

    The outline color is derived from ``idx``. The label is anchored at the
    rectangle's top-right corner, or at its bottom-right corner when the top
    edge lies within ``TOP_NO_LABEL_ZONE`` of the top of the image.
    """
    outline = _color(idx)
    red, green, blue = outline[0], outline[1], outline[2]

    # Pick black or white label text depending on how bright the label background is
    luminance = red * 0.3 + green * 0.59 + blue * 0.11
    if luminance > 90:
        text_color = (0, 0, 0, 255)
    else:
        text_color = (255, 255, 255, 255)

    label = str(idx)
    if rect["top"] <= TOP_NO_LABEL_ZONE:
        label_location = (rect["right"], rect["bottom"])
        label_anchor = "rt"
    else:
        label_location = (rect["right"], rect["top"])
        label_anchor = "rb"

    # Translucent fill plus a solid outline for the region itself
    draw.rectangle(
        ((rect["left"], rect["top"]), (rect["right"], rect["bottom"])),
        outline=outline,
        fill=(red, green, blue, 48),
        width=2,
    )

    # TODO: Having trouble with these types being partially Unknown.
    left, top, right, bottom = draw.textbbox(label_location, label, font=font, anchor=label_anchor, align="center")  # type: ignore
    # Solid background, padded by 3 on every side, behind the label text
    draw.rectangle((left - 3, top - 3, right + 3, bottom + 3), fill=outline)

    # TODO: Having trouble with these types being partially Unknown.
    draw.text(label_location, label, fill=text_color, font=font, anchor=label_anchor, align="center")  # type: ignore
|
||||
|
||||
|
||||
def _color(identifier: int) -> Tuple[int, int, int, int]:
|
||||
rnd = random.Random(int(identifier))
|
||||
color = [rnd.randint(0, 255), rnd.randint(125, 255), rnd.randint(0, 50)]
|
||||
rnd.shuffle(color)
|
||||
color.append(255)
|
||||
return cast(Tuple[int, int, int, int], tuple(color))
|
||||
@@ -0,0 +1,317 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from agentdhal_core.tools._base import ParametersSchema, ToolSchema
|
||||
|
||||
|
||||
def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema:
    """Convert a function-style tool definition dict into a ToolSchema.

    Reads the name, description, parameter properties and required list from
    ``tooldef["function"]``. The parameters type is always "object".
    """
    function_spec = tooldef["function"]
    tool_name = function_spec["name"]
    tool_description = function_spec["description"]
    parameter_spec = function_spec["parameters"]
    arguments = ParametersSchema(
        type="object",
        properties=parameter_spec["properties"],
        required=parameter_spec["required"],
    )
    return ToolSchema(name=tool_name, description=tool_description, parameters=arguments)
|
||||
|
||||
|
||||
REASONING_TOOL_PROMPT = (
    "A short description of the action to be performed and reason for doing so, do not mention the user."
)


def _define_tool(name: str, description: str, /, **arguments: Dict[str, str]) -> ToolSchema:
    """Build a tool schema whose first parameter is always the required "reasoning" string.

    Each keyword argument adds one further parameter, in the order given; its
    value is the property definition (type and description). Every parameter
    is listed as required.
    """
    properties: Dict[str, Any] = {
        "reasoning": {"type": "string", "description": REASONING_TOOL_PROMPT},
    }
    properties.update(arguments)
    return _load_tool(
        {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": list(properties),
                },
            },
        }
    )


# Navigation tools
TOOL_VISIT_URL: ToolSchema = _define_tool(
    "visit_url",
    "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).",
    url={"type": "string", "description": "The URL to visit in the browser."},
)

TOOL_WEB_SEARCH: ToolSchema = _define_tool(
    "web_search",
    "Performs a web search on Bing.com with the given query.",
    query={"type": "string", "description": "The web search query to use."},
)

TOOL_HISTORY_BACK: ToolSchema = _define_tool(
    "history_back",
    "Navigates back one page in the browser's history. This is equivalent to clicking the browser back button.",
)

# Whole-page scrolling tools
TOOL_SCROLL_UP: ToolSchema = _define_tool(
    "scroll_up",
    "Scrolls the entire browser viewport one page UP towards the beginning.",
)

TOOL_SCROLL_DOWN: ToolSchema = _define_tool(
    "scroll_down",
    "Scrolls the entire browser viewport one page DOWN towards the end.",
)

# Element interaction tools
TOOL_CLICK: ToolSchema = _define_tool(
    "click",
    "Clicks the mouse on the target with the given id.",
    target_id={"type": "integer", "description": "The numeric id of the target to click."},
)

TOOL_TYPE: ToolSchema = _define_tool(
    "input_text",
    "Types the given text value into the specified field.",
    input_field_id={
        "type": "integer",
        "description": "The numeric id of the input field to receive the text.",
    },
    text_value={"type": "string", "description": "The text to type into the input field."},
)

TOOL_SCROLL_ELEMENT_DOWN: ToolSchema = _define_tool(
    "scroll_element_down",
    "Scrolls a given html element (e.g., a div or a menu) DOWN.",
    target_id={"type": "integer", "description": "The numeric id of the target to scroll down."},
)

TOOL_SCROLL_ELEMENT_UP: ToolSchema = _define_tool(
    "scroll_element_up",
    "Scrolls a given html element (e.g., a div or a menu) UP.",
    target_id={"type": "integer", "description": "The numeric id of the target to scroll UP."},
)

TOOL_HOVER: ToolSchema = _define_tool(
    "hover",
    "Hovers the mouse over the target with the given id.",
    target_id={"type": "integer", "description": "The numeric id of the target to hover over."},
)

# Page reading tools
TOOL_READ_PAGE_AND_ANSWER: ToolSchema = _define_tool(
    "answer_question",
    "Uses AI to answer a question about the current webpage's content.",
    question={"type": "string", "description": "The question to answer."},
)

TOOL_SUMMARIZE_PAGE: ToolSchema = _define_tool(
    "summarize_page",
    "Uses AI to summarize the entire page.",
)

TOOL_SLEEP: ToolSchema = _define_tool(
    "sleep",
    "Wait a short period of time. Call this function if the page has not yet fully loaded, or if it is determined that a small delay would increase the task's chances of success.",
)
|
||||
106
agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py
Normal file
106
agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import Any, Dict, List, TypedDict, Union
|
||||
|
||||
from agentdhal_core import FunctionCall, Image
|
||||
from agentdhal_core.models import FunctionExecutionResult
|
||||
|
||||
# Content payload aliases (presumably one per chat message role — confirm against the message types that use them).
# Either plain text, or a list mixing text parts and images.
UserContent = Union[str, List[Union[str, Image]]]
# Either plain text, or a list of function calls.
AssistantContent = Union[str, List[FunctionCall]]
# The results of executed function calls.
FunctionExecutionContent = List[FunctionExecutionResult]
# Plain text only.
SystemContent = str
|
||||
|
||||
|
||||
class DOMRectangle(TypedDict):
    """Bounding rectangle of an element, built by ``domrectangle_from_dict``.

    All fields are numbers. The field names mirror the browser's DOMRect
    (presumably the values come from the page script — confirm).
    """

    # Origin of the rectangle
    x: Union[int, float]
    y: Union[int, float]
    # Size of the rectangle
    width: Union[int, float]
    height: Union[int, float]
    # Edge coordinates
    top: Union[int, float]
    right: Union[int, float]
    bottom: Union[int, float]
    left: Union[int, float]
|
||||
|
||||
|
||||
class VisualViewport(TypedDict):
    """Viewport measurements of a page, built by ``visualviewport_from_dict``.

    All fields are numbers. The field names mirror the browser's visual
    viewport and document element properties (presumably collected by the page
    script — confirm).
    """

    # Size of the viewport
    height: Union[int, float]
    width: Union[int, float]
    # Viewport offsets
    offsetLeft: Union[int, float]
    offsetTop: Union[int, float]
    # Viewport position within the page
    pageLeft: Union[int, float]
    pageTop: Union[int, float]
    scale: Union[int, float]
    # Client and scroll dimensions; scrollHeight is the denominator when the
    # visible and scrolled percentages of the page are computed.
    clientWidth: Union[int, float]
    clientHeight: Union[int, float]
    scrollWidth: Union[int, float]
    scrollHeight: Union[int, float]
|
||||
|
||||
|
||||
class InteractiveRegion(TypedDict):
    """An interactive page element, built by ``interactiveregion_from_dict``."""

    tag_name: str
    role: str
    # Read from the "aria-name" key of the source dict
    aria_name: str
    # Read from the "v-scrollable" key of the source dict
    v_scrollable: bool
    # Rectangles occupied by the element
    rects: List[DOMRectangle]
|
||||
|
||||
|
||||
# Helper functions for dealing with JSON. Not sure there's a better way?
|
||||
|
||||
|
||||
def _get_str(d: Any, k: str) -> str:
|
||||
val = d[k]
|
||||
assert isinstance(val, str)
|
||||
return val
|
||||
|
||||
|
||||
def _get_number(d: Any, k: str) -> Union[int, float]:
|
||||
val = d[k]
|
||||
assert isinstance(val, int) or isinstance(val, float)
|
||||
return val
|
||||
|
||||
|
||||
def _get_bool(d: Any, k: str) -> bool:
|
||||
val = d[k]
|
||||
assert isinstance(val, bool)
|
||||
return val
|
||||
|
||||
|
||||
def domrectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle:
    """Build a typed ``DOMRectangle`` from an untyped JSON mapping.

    Every field is read through ``_get_number``, so a missing key raises
    ``KeyError`` and a non-numeric value trips its assertion.
    """
    typed: DOMRectangle = {
        "x": _get_number(rect, "x"),
        "y": _get_number(rect, "y"),
        "width": _get_number(rect, "width"),
        "height": _get_number(rect, "height"),
        "top": _get_number(rect, "top"),
        "right": _get_number(rect, "right"),
        "bottom": _get_number(rect, "bottom"),
        "left": _get_number(rect, "left"),
    }
    return typed
|
||||
|
||||
|
||||
def interactiveregion_from_dict(region: Dict[str, Any]) -> InteractiveRegion:
    """Build a typed ``InteractiveRegion`` from the JSON record emitted by the page script.

    The script uses hyphenated keys (``aria-name``, ``v-scrollable``); they are
    mapped onto the underscore field names of the TypedDict here.
    """
    # Rectangles are converted first, then the scalar fields are validated.
    rectangles: List[DOMRectangle] = [domrectangle_from_dict(raw) for raw in region["rects"]]

    typed: InteractiveRegion = {
        "tag_name": _get_str(region, "tag_name"),
        "role": _get_str(region, "role"),
        "aria_name": _get_str(region, "aria-name"),
        "v_scrollable": _get_bool(region, "v-scrollable"),
        "rects": rectangles,
    }
    return typed
|
||||
|
||||
|
||||
def visualviewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport:
    """Build a typed ``VisualViewport`` from an untyped JSON mapping.

    Every field is read through ``_get_number``, so a missing key raises
    ``KeyError`` and a non-numeric value trips its assertion.
    """
    typed: VisualViewport = {
        "height": _get_number(viewport, "height"),
        "width": _get_number(viewport, "width"),
        "offsetLeft": _get_number(viewport, "offsetLeft"),
        "offsetTop": _get_number(viewport, "offsetTop"),
        "pageLeft": _get_number(viewport, "pageLeft"),
        "pageTop": _get_number(viewport, "pageTop"),
        "scale": _get_number(viewport, "scale"),
        "clientWidth": _get_number(viewport, "clientWidth"),
        "clientHeight": _get_number(viewport, "clientHeight"),
        "scrollWidth": _get_number(viewport, "scrollWidth"),
        "scrollHeight": _get_number(viewport, "scrollHeight"),
    }
    return typed
|
||||
429
agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js
Normal file
429
agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js
Normal file
@@ -0,0 +1,429 @@
|
||||
var MultimodalWebSurfer = MultimodalWebSurfer || (function() {
|
||||
let nextLabel = 10;
|
||||
|
||||
let roleMapping = {
|
||||
"a": "link",
|
||||
"area": "link",
|
||||
"button": "button",
|
||||
"input, type=button": "button",
|
||||
"input, type=checkbox": "checkbox",
|
||||
"input, type=email": "textbox",
|
||||
"input, type=number": "spinbutton",
|
||||
"input, type=radio": "radio",
|
||||
"input, type=range": "slider",
|
||||
"input, type=reset": "button",
|
||||
"input, type=search": "searchbox",
|
||||
"input, type=submit": "button",
|
||||
"input, type=tel": "textbox",
|
||||
"input, type=text": "textbox",
|
||||
"input, type=url": "textbox",
|
||||
"search": "search",
|
||||
"select": "combobox",
|
||||
"option": "option",
|
||||
"textarea": "textbox"
|
||||
};
|
||||
|
||||
let getCursor = function(elm) {
    // Resolved CSS `cursor` value of the element (after the cascade is applied).
    const computed = window.getComputedStyle(elm);
    return computed.cursor;
};
|
||||
|
||||
let getInteractiveElements = function() {
    // Collects, in discovery order and without duplicates, the elements a user
    // could plausibly interact with:
    //   1. natively interactive elements and elements with click/focus hooks,
    //   2. elements carrying one of the interactive ARIA roles below,
    //   3. elements whose computed cursor hints at interactivity.
    const roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"];
    const inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"];

    // `results` preserves order; `seen` makes the duplicate check O(1) instead
    // of scanning the array for every node in the document.
    const results = [];
    const seen = new Set();
    const add = function(node) {
        if (!seen.has(node)) {
            seen.add(node);
            results.push(node);
        }
    };

    // Get the main interactive elements
    document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])").forEach(add);

    // Anything not already included that has a suitable role
    document.querySelectorAll("[role]").forEach(function(node) {
        if (roles.indexOf(node.getAttribute("role")) > -1) {
            add(node);
        }
    });

    // Any element that changes the cursor to something implying interactivity
    const allNodes = document.querySelectorAll("*");
    for (let i = 0; i < allNodes.length; i++) {
        let node = allNodes[i];

        // Cursor is default, or does not suggest interactivity
        const cursor = getCursor(node);
        if (inertCursors.indexOf(cursor) >= 0) {
            continue;
        }

        // Move up to the first instance of this cursor change.
        // `parent` is declared locally (an undeclared assignment would write to
        // window.parent), and parentElement is used so the walk stops at <html>
        // instead of handing the Document to getComputedStyle, which throws.
        let parent = node.parentElement;
        while (parent && getCursor(parent) == cursor) {
            node = parent;
            parent = node.parentElement;
        }

        // Add the node if it is new
        add(node);
    }

    return results;
};
|
||||
|
||||
let labelElements = function(elements) {
    // Give every not-yet-labelled element the next free numeric id.
    // Elements that already carry an `__elementId` keep it.
    for (const elm of elements) {
        if (elm.hasAttribute("__elementId")) {
            continue;
        }
        elm.setAttribute("__elementId", String(nextLabel));
        nextLabel += 1;
    }
};
|
||||
|
||||
let isTopmost = function(element, x, y) {
    // True when `element` is the hit-test result at (x, y) or an ancestor of it.
    let node = document.elementFromPoint(x, y);

    // Hack to handle elements outside the viewport: no hit at all counts as topmost.
    if (node === null) {
        return true;
    }

    for (; node; node = node.parentNode) {
        if (node == element) {
            return true;
        }
    }
    return false;
};
|
||||
|
||||
let getFocusedElementId = function() {
    // Walk up from the focused element to the nearest labelled ancestor and
    // return its `__elementId`; null when nothing on the path is labelled.
    for (let node = document.activeElement; node; node = node.parentNode) {
        // Non-element ancestors (e.g. the document) have no hasAttribute method.
        const labelled = node.hasAttribute && node.hasAttribute("__elementId");
        if (labelled) {
            return node.getAttribute("__elementId");
        }
    }
    return null;
};
|
||||
|
||||
let trimmedInnerText = function(element) {
    // Trimmed innerText of `element`; "" for a missing element or empty/absent text.
    const text = element ? element.innerText : "";
    return text ? text.trim() : "";
};
|
||||
|
||||
let getApproximateAriaName = function(element) {
    // Approximates the accessible name of `element`. Sources are tried in order
    // and the first one that applies wins:
    //   aria-labelledby -> aria-label -> <label for=id> -> wrapping <label>
    //   -> alt -> title -> the element's own trimmed text.

    // Check for aria labels
    if (element.hasAttribute("aria-labelledby")) {
        let buffer = "";
        const ids = element.getAttribute("aria-labelledby").split(" ");
        for (let i = 0; i < ids.length; i++) {
            const label = document.getElementById(ids[i]);
            if (label) {
                buffer = buffer + " " + trimmedInnerText(label);
            }
        }
        return buffer.trim();
    }

    if (element.hasAttribute("aria-label")) {
        return element.getAttribute("aria-label");
    }

    // Check for labels
    if (element.hasAttribute("id")) {
        const labelId = element.getAttribute("id");
        // Escape the id before embedding it in the selector: ids come from
        // arbitrary pages, and a quote or backslash in the raw value would make
        // querySelectorAll throw a SyntaxError.
        const labels = document.querySelectorAll("label[for='" + CSS.escape(labelId) + "']");
        let label = "";
        for (let j = 0; j < labels.length; j++) {
            label += labels[j].innerText + " ";
        }
        label = label.trim();
        if (label != "") {
            return label;
        }
    }

    if (element.parentElement && element.parentElement.tagName == "LABEL") {
        return element.parentElement.innerText;
    }

    // Check for alt text or titles
    if (element.hasAttribute("alt")) {
        return element.getAttribute("alt");
    }

    if (element.hasAttribute("title")) {
        return element.getAttribute("title");
    }

    return trimmedInnerText(element);
};
|
||||
|
||||
let getApproximateAriaRole = function(element) {
    // Returns [role, tag]. `tag` is the lower-cased tag name, extended with
    // ", type=<type>" for typed inputs so it can be looked up in roleMapping.
    let tag = element.tagName.toLowerCase();
    if (tag == "input" && element.hasAttribute("type")) {
        tag += ", type=" + element.getAttribute("type");
    }

    // An explicit role attribute always wins over the implicit mapping.
    if (element.hasAttribute("role")) {
        return [element.getAttribute("role"), tag];
    }

    const implicitRole = (tag in roleMapping) ? roleMapping[tag] : "";
    return [implicitRole, tag];
};
|
||||
|
||||
let getInteractiveRects = function() {
    // Labels the current interactive elements, then returns a map from
    // `__elementId` to a record describing each labelled element that has at
    // least one client rect whose centre point passes the isTopmost hit test.
    labelElements(getInteractiveElements());

    const results = {};
    const labelled = document.querySelectorAll("[__elementId]");
    for (let i = 0; i < labelled.length; i++) {
        const elm = labelled[i];
        const key = elm.getAttribute("__elementId");
        const rects = elm.getClientRects();
        const ariaRole = getApproximateAriaRole(elm);
        const ariaName = getApproximateAriaName(elm);
        const vScrollable = elm.scrollHeight - elm.clientHeight >= 1;

        const visibleRects = [];
        for (const rect of rects) {
            const centerX = rect.left + rect.width/2;
            const centerY = rect.top + rect.height/2;
            if (isTopmost(elm, centerX, centerY)) {
                // JSON round trip turns the rect into a plain data object.
                visibleRects.push(JSON.parse(JSON.stringify(rect)));
            }
        }

        // Elements with no qualifying rect are left out entirely.
        if (visibleRects.length > 0) {
            results[key] = {
                "tag_name": ariaRole[1],
                "role": ariaRole[0],
                "aria-name": ariaName,
                "v-scrollable": vScrollable,
                "rects": visibleRects
            };
        }
    }
    return results;
};
|
||||
|
||||
let getVisualViewport = function() {
    // Snapshot of viewport and document geometry. Every field falls back to 0
    // when its source object (visualViewport / documentElement) is unavailable.
    const vv = window.visualViewport;
    const de = document.documentElement;
    const fromViewport = function(prop) {
        return vv ? vv[prop] : 0;
    };
    const fromRoot = function(prop) {
        return de ? de[prop] : 0;
    };

    return {
        "height": fromViewport("height"),
        "width": fromViewport("width"),
        "offsetLeft": fromViewport("offsetLeft"),
        "offsetTop": fromViewport("offsetTop"),
        "pageLeft": fromViewport("pageLeft"),
        "pageTop": fromViewport("pageTop"),
        "scale": fromViewport("scale"),
        "clientWidth": fromRoot("clientWidth"),
        "clientHeight": fromRoot("clientHeight"),
        "scrollWidth": fromRoot("scrollWidth"),
        "scrollHeight": fromRoot("scrollHeight")
    };
};
|
||||
|
||||
let _getMetaTags = function() {
    // Maps each <meta> tag's `name` (preferred) or `property` to its `content`.
    // Tags without a key or without content are skipped; later tags with the
    // same key overwrite earlier ones.
    const results = {};
    for (const tag of document.querySelectorAll("meta")) {
        let key;
        if (tag.hasAttribute("name")) {
            key = tag.getAttribute("name");
        } else if (tag.hasAttribute("property")) {
            key = tag.getAttribute("property");
        } else {
            continue;
        }

        if (tag.hasAttribute("content")) {
            results[key] = tag.getAttribute("content");
        }
    }
    return results;
};
|
||||
|
||||
let _getJsonLd = function() {
    // Raw (trimmed, unparsed) text of every JSON-LD script block on the page.
    const scripts = document.querySelectorAll('script[type="application/ld+json"]');
    return Array.from(scripts, function(script) {
        return script.innerHTML.trim();
    });
};
|
||||
|
||||
// From: https://www.stevefenton.co.uk/blog/2022/12/parse-microdata-with-javascript/
|
||||
let _getMicrodata = function() {
    // Extracts HTML microdata: one object per [itemscope] element, holding its
    // itemType plus every itemprop found beneath it. A property that occurs
    // more than once is collected into an array.

    // Replace every whitespace character with a plain space and trim.
    function sanitize(input) {
        return input.replace(/\s/gi, ' ').trim();
    }

    // Store `value` under `name`, promoting to (or appending to) an array on repeats.
    function addValue(information, name, value) {
        if (!information[name]) {
            information[name] = value;
        } else if (Array.isArray(information[name])) {
            // `typeof x` is never 'array', so Array.isArray is required here;
            // otherwise a third value would wrap the existing array again.
            information[name].push(value);
        } else {
            information[name] = [information[name], value];
        }
    }

    // Recursively collect the itemprops below `item` into `information`.
    function traverseItem(item, information) {
        const children = item.children;

        for (let i = 0; i < children.length; i++) {
            const child = children[i];

            if (child.hasAttribute('itemscope')) {
                // A nested item is attached only when it is also a property of
                // the current item; it is not descended into otherwise.
                if (child.hasAttribute('itemprop')) {
                    const itemProp = child.getAttribute('itemprop');
                    const childInfo = {
                        itemType: child.getAttribute('itemtype')
                    };

                    traverseItem(child, childInfo);

                    itemProp.split(' ').forEach(propName => {
                        addValue(information, propName, childInfo);
                    });
                }
            } else if (child.hasAttribute('itemprop')) {
                const itemProp = child.getAttribute('itemprop');
                itemProp.split(' ').forEach(propName => {
                    if (propName === 'url') {
                        addValue(information, propName, child.href);
                    } else {
                        addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || ""));
                    }
                });
                traverseItem(child, information);
            } else {
                traverseItem(child, information);
            }
        }
    }

    const microdata = [];

    document.querySelectorAll("[itemscope]").forEach(function(elem) {
        const information = {
            itemType: elem.getAttribute('itemtype')
        };
        traverseItem(elem, information);
        microdata.push(information);
    });

    return microdata;
};
|
||||
|
||||
let getPageMetadata = function() {
    // Gathers structured metadata from the page. The result may contain:
    //   "jsonld"    - JSON-LD script blocks,
    //   "microdata" - HTML microdata items,
    //   "meta_tags" - name/property -> content of <meta> tags.
    // A key is present only when the corresponding data exists.
    const jsonld = _getJsonLd();
    const metaTags = _getMetaTags();
    const microdata = _getMicrodata();
    const results = {};

    // Parse one JSON-LD block, keeping the raw text when it is not valid JSON.
    const parseOrRaw = function(text) {
        try {
            return JSON.parse(text);
        }
        catch (e) {
            return text;
        }
    };

    if (jsonld.length === 1) {
        // Single block: the parsed value itself, or the raw list on failure.
        // This is the shape the original string-coercing JSON.parse(array) produced.
        try {
            results["jsonld"] = JSON.parse(jsonld[0]);
        }
        catch (e) {
            results["jsonld"] = jsonld;
        }
    } else if (jsonld.length > 1) {
        // Several blocks: passing the whole array to JSON.parse joins the
        // entries with commas and never yields valid JSON, so parse each block
        // on its own instead.
        results["jsonld"] = jsonld.map(parseOrRaw);
    }

    if (microdata.length > 0) {
        results["microdata"] = microdata;
    }

    if (Object.keys(metaTags).length > 0) {
        results["meta_tags"] = metaTags;
    }

    return results;
};
|
||||
|
||||
|
||||
let getVisibleText = function() {
    // Returns the text of every text node that has at least one non-empty
    // client rect intersecting the viewport, with blank lines removed.
    const viewportHeight = window.innerHeight || document.documentElement.clientHeight;
    const viewportWidth = window.innerWidth || document.documentElement.clientWidth;

    // A rect counts as visible when it has area and overlaps the viewport bounds.
    const intersectsViewport = function(rect) {
        return rect.width > 0 &&
            rect.height > 0 &&
            rect.bottom >= 0 &&
            rect.right >= 0 &&
            rect.top <= viewportHeight &&
            rect.left <= viewportWidth;
    };

    const walker = document.createTreeWalker(
        document.body,
        NodeFilter.SHOW_TEXT,
        null,
        false
    );

    let textInView = "";
    while (walker.nextNode()) {
        const textNode = walker.currentNode;

        // A range over the text node exposes its bounding rectangles.
        const range = document.createRange();
        range.selectNodeContents(textNode);

        for (const rect of range.getClientRects()) {
            if (!intersectsViewport(rect)) {
                continue;
            }

            textInView += textNode.nodeValue.replace(/\s+/g, " ");

            // Text whose parent is not displayed inline ends with a line break.
            const parent = textNode.parentNode;
            if (parent) {
                const display = window.getComputedStyle(parent).display;
                if (["inline", "hidden", "none"].indexOf(display) === -1) {
                    textInView += "\n";
                }
            }
            break; // One visible rect is enough for this node
        }
    }

    // Remove blank lines from textInView
    return textInView.replace(/^\s*\n/gm, "").trim().replace(/\n+/g, "\n");
};
|
||||
|
||||
return {
|
||||
getInteractiveRects: getInteractiveRects,
|
||||
getVisualViewport: getVisualViewport,
|
||||
getFocusedElementId: getFocusedElementId,
|
||||
getPageMetadata: getPageMetadata,
|
||||
getVisibleText: getVisibleText,
|
||||
};
|
||||
})();
|
||||
@@ -0,0 +1,578 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from playwright._impl._errors import Error as PlaywrightError
|
||||
from playwright._impl._errors import TimeoutError
|
||||
from playwright.async_api import Download, Page
|
||||
|
||||
from ._types import (
|
||||
InteractiveRegion,
|
||||
VisualViewport,
|
||||
interactiveregion_from_dict,
|
||||
visualviewport_from_dict,
|
||||
)
|
||||
|
||||
markitdown: ModuleType | None = None
|
||||
try:
|
||||
# Suppress warnings from markitdown -- which is pretty chatty
|
||||
warnings.filterwarnings(action="ignore", module="markitdown")
|
||||
import markitdown
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class PlaywrightController:
|
||||
"""
|
||||
A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling.
|
||||
|
||||
Args:
|
||||
downloads_folder (str | None): The folder to save downloads to. If None, downloads are not saved.
|
||||
animate_actions (bool): Whether to animate the actions (create fake cursor to click).
|
||||
viewport_width (int): The width of the viewport.
|
||||
viewport_height (int): The height of the viewport.
|
||||
_download_handler (Optional[Callable[[Download], None]]): A function to handle downloads.
|
||||
to_resize_viewport (bool): Whether to resize the viewport
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    downloads_folder: str | None = None,
    animate_actions: bool = False,
    viewport_width: int = 1440,
    viewport_height: int = 900,
    _download_handler: Optional[Callable[[Download], None]] = None,
    to_resize_viewport: bool = True,
) -> None:
    """
    Initialize the PlaywrightController.

    Validates the arguments, stores the configuration, and loads the
    ``page_script.js`` file that lives next to this module into memory.
    """
    # Argument sanity checks.
    assert isinstance(animate_actions, bool)
    assert isinstance(viewport_width, int)
    assert isinstance(viewport_height, int)
    assert viewport_height > 0
    assert viewport_width > 0

    # Configuration.
    self.downloads_folder = downloads_folder
    self.animate_actions = animate_actions
    self.viewport_width = viewport_width
    self.viewport_height = viewport_height
    self.to_resize_viewport = to_resize_viewport
    self._download_handler = _download_handler

    # Runtime state.
    self._page_script: str = ""
    self.last_cursor_position: Tuple[float, float] = (0.0, 0.0)
    self._markdown_converter: Optional[Any] | None = None

    # Read page_script
    module_dir = os.path.abspath(os.path.dirname(__file__))
    script_path = os.path.join(module_dir, "page_script.js")
    with open(script_path, "rt", encoding="utf-8") as fh:
        self._page_script = fh.read()
|
||||
|
||||
async def sleep(self, page: Page, duration: Union[int, float]) -> None:
    """
    Pause the execution for a specified duration.

    Args:
        page (Page): The Playwright page object.
        duration (Union[int, float]): The duration to sleep in seconds. It is
            converted to milliseconds before being handed to Playwright.
    """
    assert page is not None
    # wait_for_timeout takes milliseconds; callers pass seconds (e.g. 0.2).
    await page.wait_for_timeout(duration * 1000)
|
||||
|
||||
async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion]:
    """
    Retrieve interactive regions from the web page.

    Args:
        page (Page): The Playwright page object.

    Returns:
        Dict[str, InteractiveRegion]: Interactive regions keyed by element identifier.
    """
    assert page is not None

    # Best-effort (re)injection of the helper script: errors are ignored here,
    # and a genuinely missing helper surfaces in the call below.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    # Read the regions from the DOM
    raw = cast(Dict[str, Dict[str, Any]], await page.evaluate("MultimodalWebSurfer.getInteractiveRects();"))
    assert isinstance(raw, dict)

    # Convert the results into appropriate types
    regions: Dict[str, InteractiveRegion] = {}
    for element_id, record in raw.items():
        assert isinstance(element_id, str)
        regions[element_id] = interactiveregion_from_dict(record)
    return regions
|
||||
|
||||
async def get_visual_viewport(self, page: Page) -> VisualViewport:
    """
    Retrieve the visual viewport of the web page.

    Args:
        page (Page): The Playwright page object.

    Returns:
        VisualViewport: The visual viewport of the page.
    """
    assert page is not None

    # Best-effort (re)injection of the helper script; errors are ignored here.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    raw_viewport = await page.evaluate("MultimodalWebSurfer.getVisualViewport();")
    return visualviewport_from_dict(raw_viewport)
|
||||
|
||||
async def get_focused_rect_id(self, page: Page) -> str | None:
    """
    Retrieve the ID of the currently focused element.

    Args:
        page (Page): The Playwright page object.

    Returns:
        str | None: The ID of the focused element, or None if no control has focus.
    """
    assert page is not None

    # Best-effort (re)injection of the helper script; errors are ignored here.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    focused = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();")
    if focused is None:
        return None
    return str(focused)
|
||||
|
||||
async def get_page_metadata(self, page: Page) -> Dict[str, Any]:
    """
    Retrieve metadata from the web page.

    Args:
        page (Page): The Playwright page object.

    Returns:
        Dict[str, Any]: A dictionary of page metadata.
    """
    assert page is not None

    # Best-effort (re)injection of the helper script; errors are ignored here.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    metadata = await page.evaluate("MultimodalWebSurfer.getPageMetadata();")
    assert isinstance(metadata, dict)
    return cast(Dict[str, Any], metadata)
|
||||
|
||||
async def on_new_page(self, page: Page) -> None:
    """
    Handle actions to perform on a new page.

    Registers the download handler, optionally resizes the viewport, installs
    the page script as an init script, and waits for the page to load.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    page.on("download", self._download_handler)  # type: ignore

    resize_requested = self.to_resize_viewport and self.viewport_width and self.viewport_height
    if resize_requested:
        viewport_size = {"width": self.viewport_width, "height": self.viewport_height}
        await page.set_viewport_size(viewport_size)

    await self.sleep(page, 0.2)

    # Install the helper script so that it runs on every subsequent navigation.
    script_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
    await page.add_init_script(path=script_path)
    await page.wait_for_load_state()
|
||||
|
||||
async def back(self, page: Page) -> None:
    """
    Navigate back to the previous page in the browsing history.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    # The navigation result returned by Playwright is intentionally discarded.
    await page.go_back()
|
||||
|
||||
async def visit_page(self, page: Page, url: str) -> Tuple[bool, bool]:
    """
    Visit a specified URL.

    A normal navigation is attempted first. If it is aborted
    (``net::ERR_ABORTED``) and a downloads folder is configured, the URL is
    treated as a file download: the file is saved into the downloads folder
    and the page is replaced by a small confirmation document.

    Args:
        page (Page): The Playwright page object.
        url (str): The URL to visit.

    Returns:
        Tuple[bool, bool]: A tuple indicating whether to reset prior metadata hash and last download.

    Raises:
        Exception: Any navigation error that is not an aborted navigation
            handled as a download is re-raised unchanged.
    """
    import html  # local import: only needed for the download confirmation page

    assert page is not None
    reset_prior_metadata_hash = False
    reset_last_download = False
    try:
        # Regular webpage
        await page.goto(url)
        await page.wait_for_load_state()
        reset_prior_metadata_hash = True
    except Exception as e_outer:
        # Anything other than an aborted navigation with downloads enabled is a real error.
        if not (self.downloads_folder and "net::ERR_ABORTED" in str(e_outer)):
            raise

        # Downloaded file
        async with page.expect_download() as download_info:
            try:
                await page.goto(url)
            except Exception as e_inner:
                # The abort is expected while the download starts; anything else is not.
                if "net::ERR_ABORTED" not in str(e_inner):
                    raise
            download = await download_info.value
            fname = os.path.join(self.downloads_folder, download.suggested_filename)
            await download.save_as(fname)

            # The suggested filename is controlled by the remote server, so it
            # is escaped before being embedded into HTML that the page loads.
            shown_name = html.escape(download.suggested_filename)
            shown_path = html.escape(fname)
            message = f"<body style=\"margin: 20px;\"><h1>Successfully downloaded '{shown_name}' to local path:<br><br>{shown_path}</h1></body>"
            await page.goto(
                "data:text/html;base64," + base64.b64encode(message.encode("utf-8")).decode("utf-8")
            )
            reset_last_download = True
    return reset_prior_metadata_hash, reset_last_download
|
||||
|
||||
async def page_down(self, page: Page) -> None:
    """
    Scroll the page down by one viewport height minus 50 pixels.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    # Scroll 50 pixels less than a full viewport height.
    scroll_amount = self.viewport_height - 50
    await page.evaluate(f"window.scrollBy(0, {scroll_amount});")
|
||||
|
||||
async def page_up(self, page: Page) -> None:
    """
    Scroll the page up by one viewport height minus 50 pixels.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    # Scroll 50 pixels less than a full viewport height.
    scroll_amount = self.viewport_height - 50
    await page.evaluate(f"window.scrollBy(0, -{scroll_amount});")
|
||||
|
||||
async def gradual_cursor_animation(
    self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float
) -> None:
    """
    Animate the cursor movement gradually from start to end coordinates.

    The ``red-cursor`` element must already exist on the page (it is created
    by ``add_cursor_box``). The cursor is moved in 20 equal steps, 50 ms apart,
    and the final step lands exactly on the end coordinates, which are then
    recorded in ``last_cursor_position``.

    Args:
        page (Page): The Playwright page object.
        start_x (float): The starting x-coordinate.
        start_y (float): The starting y-coordinate.
        end_x (float): The ending x-coordinate.
        end_y (float): The ending y-coordinate.
    """
    # animation helper
    steps = 20
    # Iterate 1..steps so the last frame is at 100% of the path. Starting the
    # range at 0 would stop at 95% while last_cursor_position claims the end.
    for step in range(1, steps + 1):
        fraction = step / steps
        x = start_x + (end_x - start_x) * fraction
        y = start_y + (end_y - start_y) * fraction
        await page.evaluate(f"""
            (function() {{
                let cursor = document.getElementById('red-cursor');
                cursor.style.left = '{x}px';
                cursor.style.top = '{y}px';
            }})();
        """)
        await asyncio.sleep(0.05)

    self.last_cursor_position = (end_x, end_y)
|
||||
|
||||
async def add_cursor_box(self, page: Page, identifier: str) -> None:
    """
    Add a red cursor box around the element with the given identifier.

    The element (if found) gets a red border, and after a short pause a red
    dot with the id ``red-cursor`` is appended to the document body.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.
    """
    # Step 1: highlight the target element.
    highlight_script = f"""
        (function() {{
            let elm = document.querySelector("[__elementId='{identifier}']");
            if (elm) {{
                elm.style.transition = 'border 0.3s ease-in-out';
                elm.style.border = '2px solid red';
            }}
        }})();
    """
    await page.evaluate(highlight_script)
    # Pause for 0.3 s, the same duration as the CSS border transition.
    await asyncio.sleep(0.3)

    # Step 2: create a red cursor
    create_cursor_script = """
        (function() {
            let cursor = document.createElement('div');
            cursor.id = 'red-cursor';
            cursor.style.width = '10px';
            cursor.style.height = '10px';
            cursor.style.backgroundColor = 'red';
            cursor.style.position = 'absolute';
            cursor.style.borderRadius = '50%';
            cursor.style.zIndex = '10000';
            document.body.appendChild(cursor);
        })();
    """
    await page.evaluate(create_cursor_script)
|
||||
|
||||
async def remove_cursor_box(self, page: Page, identifier: str) -> None:
    """
    Remove the red cursor box around the element with the given identifier.

    Clears the element's border and removes the ``red-cursor`` element; either
    step is skipped when its target does not exist.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.
    """
    # Remove the highlight and cursor
    cleanup_script = f"""
        (function() {{
            let elm = document.querySelector("[__elementId='{identifier}']");
            if (elm) {{
                elm.style.border = '';
            }}
            let cursor = document.getElementById('red-cursor');
            if (cursor) {{
                cursor.remove();
            }}
        }})();
    """
    await page.evaluate(cleanup_script)
|
||||
|
||||
async def click_id(self, page: Page, identifier: str) -> Page | None:
    """
    Click the element with the given identifier.

    The element is scrolled into view and clicked at the centre of its
    bounding box. When ``animate_actions`` is set, the element is highlighted
    and a fake cursor is moved to it first.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.

    Returns:
        Page | None: The new page if a new page is opened, otherwise None.

    Raises:
        ValueError: If the element does not appear within 5 seconds.
    """
    assert page is not None
    new_page: Page | None = None
    target = page.locator(f"[__elementId='{identifier}']")

    # See if it exists
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Bring it into view and give the page a moment to settle.
    await target.scroll_into_view_if_needed()
    await asyncio.sleep(0.3)

    box = cast(Dict[str, Union[int, float]], await target.bounding_box())

    animated = self.animate_actions
    if animated:
        await self.add_cursor_box(page, identifier)

    # Click position: the centre of the bounding box.
    center_x = box["x"] + box["width"] / 2
    center_y = box["y"] + box["height"] / 2

    if animated:
        # Move cursor to the box slowly
        start_x, start_y = self.last_cursor_position
        await self.gradual_cursor_animation(page, start_x, start_y, center_x, center_y)
        await asyncio.sleep(0.1)

    try:
        # Give it a chance to open a new page
        async with page.expect_event("popup", timeout=1000) as page_info:  # type: ignore
            await page.mouse.click(center_x, center_y, delay=10)
            new_page = await page_info.value  # type: ignore
            assert isinstance(new_page, Page)
            await self.on_new_page(new_page)
    except TimeoutError:
        # No popup appeared within the timeout; the click stays on this page.
        pass

    if animated:
        await self.remove_cursor_box(page, identifier)

    return new_page  # type: ignore
|
||||
|
||||
async def hover_id(self, page: Page, identifier: str) -> None:
    """
    Hover the mouse over the element with the given identifier.

    The element is scrolled into view and the mouse is moved to the centre of
    its bounding box. When ``animate_actions`` is set, the element is
    highlighted and a fake cursor is moved to it first.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.

    Raises:
        ValueError: If the element does not appear within 5 seconds.
    """
    assert page is not None
    target = page.locator(f"[__elementId='{identifier}']")

    # See if it exists
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Bring it into view and give the page a moment to settle.
    await target.scroll_into_view_if_needed()
    await asyncio.sleep(0.3)

    box = cast(Dict[str, Union[int, float]], await target.bounding_box())

    animated = self.animate_actions
    if animated:
        await self.add_cursor_box(page, identifier)

    # Hover position: the centre of the bounding box.
    center_x = box["x"] + box["width"] / 2
    center_y = box["y"] + box["height"] / 2

    if animated:
        # Move cursor to the box slowly
        start_x, start_y = self.last_cursor_position
        await self.gradual_cursor_animation(page, start_x, start_y, center_x, center_y)
        await asyncio.sleep(0.1)

    await page.mouse.move(center_x, center_y)

    if animated:
        await self.remove_cursor_box(page, identifier)
|
||||
|
||||
async def fill_id(self, page: Page, identifier: str, value: str, press_enter: bool = True) -> None:
    """
    Fill the element with the given identifier with the specified value.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.
        value (str): The value to fill.
        press_enter (bool): Whether to press Enter on the element after filling it.

    Raises:
        ValueError: If the element does not appear within 5 seconds.
    """
    assert page is not None
    target = page.locator(f"[__elementId='{identifier}']")

    # See if it exists
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Fill it
    await target.scroll_into_view_if_needed()
    box = cast(Dict[str, Union[int, float]], await target.bounding_box())

    animated = self.animate_actions
    if animated:
        await self.add_cursor_box(page, identifier)
        # Move cursor to the box slowly
        start_x, start_y = self.last_cursor_position
        end_x = box["x"] + box["width"] / 2
        end_y = box["y"] + box["height"] / 2
        await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
        await asyncio.sleep(0.1)

    # Focus on the element
    await target.focus()

    if animated:
        # Short text is typed at a randomised, human-like pace; long text is typed fast.
        delay_typing_speed = 50 + 100 * random.random() if len(value) < 100 else 10
        await target.press_sequentially(value, delay=delay_typing_speed)
    else:
        # Try a direct fill first; fall back to typing when Playwright rejects it.
        try:
            await target.fill(value)
        except PlaywrightError:
            await target.press_sequentially(value)

    if press_enter:
        await target.press("Enter")

    if animated:
        await self.remove_cursor_box(page, identifier)
|
||||
|
||||
async def scroll_id(self, page: Page, identifier: str, direction: str) -> None:
    """
    Scroll the element with the given identifier in the specified direction.

    The element is scrolled by one ``clientHeight`` per call, clamped to the
    scrollable range. Nothing happens when no element matches the identifier.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.
        direction (str): The direction to scroll ("up" or "down"). Any value
            other than "up" scrolls down.
    """
    assert page is not None
    # The identifier and direction are handed to the page as evaluate() arguments
    # rather than being formatted into the script text, so quotes or other special
    # characters in either value cannot alter the JavaScript that gets executed.
    await page.evaluate(
        """([identifier, direction]) => {
            let elm = document.querySelector("[__elementId='" + identifier + "']");
            if (elm) {
                if (direction == "up") {
                    elm.scrollTop = Math.max(0, elm.scrollTop - elm.clientHeight);
                }
                else {
                    elm.scrollTop = Math.min(elm.scrollHeight - elm.clientHeight, elm.scrollTop + elm.clientHeight);
                }
            }
        }""",
        [identifier, direction],
    )
|
||||
|
||||
async def get_webpage_text(self, page: Page, n_lines: int = 50) -> str:
    """
    Retrieve the text content of the web page.

    The first ``n_lines`` lines of ``document.body.innerText`` are taken and the
    blank ones among them are dropped. Any failure yields an empty string.

    Args:
        page (Page): The Playwright page object.
        n_lines (int): The number of lines to return from the page inner text.

    Returns:
        str: The text content of the page.
    """
    assert page is not None
    try:
        inner_text = await page.evaluate("""() => {
            return document.body.innerText;
        }""")
        leading_lines = inner_text.split("\n")[:n_lines]
        # remove empty lines
        non_blank = [entry for entry in leading_lines if entry.strip()]
        page_text = "\n".join(non_blank)
        assert isinstance(page_text, str)
        return page_text
    except Exception:
        return ""
|
||||
|
||||
async def get_visible_text(self, page: Page) -> str:
    """
    Retrieve the text content of the browser viewport (approximately).

    Args:
        page (Page): The Playwright page object.

    Returns:
        str: The text content of the page.
    """
    assert page is not None

    # Evaluate the helper script first; a failure here is deliberately ignored
    # and the getVisibleText() call below is attempted regardless.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    visible_text = await page.evaluate("MultimodalWebSurfer.getVisibleText();")
    assert isinstance(visible_text, str)
    return visible_text
|
||||
|
||||
async def get_page_markdown(self, page: Page) -> str:
    """
    Retrieve the markdown content of the web page.

    When ``markitdown`` is available the page's outer HTML is converted with a
    lazily created, cached ``MarkItDown`` converter. Otherwise the first 200
    lines of the page's plain text are returned instead.

    Args:
        page (Page): The Playwright page object.

    Returns:
        str: The markdown content of the page.
    """
    assert page is not None

    # Create the converter once and reuse it. The conversion itself must run on
    # every call, not only on the call that happens to create the converter.
    if self._markdown_converter is None and markitdown is not None:
        self._markdown_converter = markitdown.MarkItDown()

    if self._markdown_converter is None:
        # markitdown is not installed: fall back to plain text.
        return await self.get_webpage_text(page, n_lines=200)

    html = await page.evaluate("document.documentElement.outerHTML;")
    res = self._markdown_converter.convert_stream(
        io.BytesIO(html.encode("utf-8")), file_extension=".html", url=page.url
    )
    assert hasattr(res, "text_content") and isinstance(res.text_content, str)
    return res.text_content
|
||||
56
agent_dhal/agentdhal_extensions/auth/azure/__init__.py
Normal file
56
agent_dhal/agentdhal_extensions/auth/azure/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import List
|
||||
|
||||
from agentdhal_core import Component, ComponentBase
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from azure.core.credentials import TokenProvider
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
|
||||
|
||||
class TokenProviderConfig(BaseModel):
    """Serializable configuration for :class:`AzureTokenProvider`."""

    # Name of the credential class to instantiate (e.g. "DefaultAzureCredential").
    provider_kind: str
    # Scopes the bearer token is requested for.
    scopes: List[str]
|
||||
|
||||
|
||||
class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]):
    """Callable component that returns bearer tokens for a credential and a set of scopes.

    The instance wraps the callable produced by ``get_bearer_token_provider`` and
    delegates to it when called.
    """

    component_type = "token_provider"
    component_config_schema = TokenProviderConfig
    component_provider_override = "agentdhal_extensions.auth.azure.AzureTokenProvider"

    def __init__(self, credential: TokenProvider, *scopes: str):
        self.credential = credential
        self.scopes = list(scopes)
        self.provider = get_bearer_token_provider(self.credential, *self.scopes)

    def __call__(self) -> str:
        """Return the token produced by the wrapped bearer token provider."""
        return self.provider()

    def _to_config(self) -> TokenProviderConfig:
        """Dump the configuration that would be required to create a new instance of a component matching the configuration of this instance.

        Returns:
            TokenProviderConfig: The configuration of the component.

        Raises:
            ValueError: If the credential is not a ``DefaultAzureCredential``.
        """
        if not isinstance(self.credential, DefaultAzureCredential):
            raise ValueError("Only DefaultAzureCredential is supported")

        # NOTE: we are not currently inspecting the chained credentials, so this could result in a loss of information
        return TokenProviderConfig(provider_kind="DefaultAzureCredential", scopes=self.scopes)

    @classmethod
    def _from_config(cls, config: TokenProviderConfig) -> Self:
        """Create a new instance of the component from a configuration object.

        Args:
            config (TokenProviderConfig): The configuration object.

        Returns:
            Self: The new instance of the component.

        Raises:
            ValueError: If ``config.provider_kind`` is not "DefaultAzureCredential".
        """
        if config.provider_kind != "DefaultAzureCredential":
            raise ValueError("Only DefaultAzureCredential is supported")

        return cls(DefaultAzureCredential(), *config.scopes)
|
||||
46
agent_dhal/agentdhal_extensions/cache_store/diskcache.py
Normal file
46
agent_dhal/agentdhal_extensions/cache_store/diskcache.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
import diskcache
|
||||
from agentdhal_core import CacheStore, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DiskCacheStoreConfig(BaseModel):
    """Serializable settings needed to recreate a DiskCacheStore."""

    # Path where cache is stored
    directory: str
    # Could add other diskcache.Cache parameters like size_limit, etc.
|
||||
|
||||
|
||||
class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]):
    """
    A typed CacheStore implementation that uses diskcache as the underlying storage.
    See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage.

    Args:
        cache_instance: An instance of diskcache.Cache.
                        The user is responsible for managing the DiskCache instance's lifetime.
    """

    component_config_schema = DiskCacheStoreConfig
    component_provider_override = "agentdhal_extensions.cache_store.diskcache.DiskCacheStore"

    def __init__(self, cache_instance: diskcache.Cache):  # type: ignore[no-any-unimported]
        self.cache = cache_instance

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """Return the value stored under ``key``, or ``default`` when the cache has none."""
        stored = self.cache.get(key, default)  # type: ignore[reportUnknownMemberType]
        return cast(Optional[T], stored)

    def set(self, key: str, value: T) -> None:
        """Store ``value`` under ``key`` in the underlying cache."""
        untyped_value = cast(Any, value)
        self.cache.set(key, untyped_value)  # type: ignore[reportUnknownMemberType]

    def _to_config(self) -> DiskCacheStoreConfig:
        """Build a config carrying the directory of the wrapped cache instance."""
        cache_directory = self.cache.directory
        return DiskCacheStoreConfig(directory=cache_directory)

    @classmethod
    def _from_config(cls, config: DiskCacheStoreConfig) -> Self:
        """Create a store backed by a new ``diskcache.Cache`` at ``config.directory``."""
        new_cache = diskcache.Cache(config.directory)
        return cls(cache_instance=new_cache)  # type: ignore[no-any-return]
|
||||
142
agent_dhal/agentdhal_extensions/cache_store/redis.py
Normal file
142
agent_dhal/agentdhal_extensions/cache_store/redis.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional, TypeVar, cast
|
||||
|
||||
import redis
|
||||
from agentdhal_core import CacheStore, Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisStoreConfig(BaseModel):
    """Serializable connection settings needed to recreate a RedisStore."""

    # Server location and logical database index.
    host: str = "localhost"
    port: int = 6379
    db: int = 0

    # Add other relevant redis connection parameters
    username: Optional[str] = None
    password: Optional[str] = None
    ssl: bool = False
    socket_timeout: Optional[float] = None
|
||||
|
||||
|
||||
class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
    """
    A typed CacheStore implementation that uses redis as the underlying storage.
    See :class:`~agentdhal_extensions.models.cache.ChatCompletionCache` for an example of usage.

    This implementation provides automatic serialization and deserialization for:
        - Pydantic models (uses model_dump_json on write, JSON parsing on read)
        - Primitive types (strings, numbers, etc.)

    Args:
        redis_instance: An instance of `redis.Redis`.
                        The user is responsible for managing the Redis instance's lifetime.
    """

    component_config_schema = RedisStoreConfig
    component_provider_override = "agentdhal_extensions.cache_store.redis.RedisStore"

    def __init__(self, redis_instance: redis.Redis):
        self.cache = redis_instance

    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
        """
        Retrieve a value from the Redis cache.

        Values that come back as bytes are decoded as UTF-8 and then parsed as
        JSON; when the text is not valid JSON the decoded string itself is
        returned. Non-bytes values are returned unchanged.

        Args:
            key: The key to retrieve
            default: Value to return if the key doesn't exist, the bytes are not
                valid UTF-8, or Redis raises an error

        Returns:
            The stored value (parsed where possible), otherwise the default
        """
        try:
            fetched = self.cache.get(key)
        except (redis.RedisError, ConnectionError):
            # Redis-specific errors are swallowed and the default is returned
            return default

        if fetched is None:
            return default

        if not isinstance(fetched, bytes):
            # Backward compatibility for primitives
            return cast(Optional[T], fetched)

        # First try to decode as UTF-8 string
        try:
            text = fetched.decode("utf-8")
        except UnicodeDecodeError:
            return default

        # Try to parse as JSON and return the parsed object
        try:
            return cast(Optional[T], json.loads(text))
        except json.JSONDecodeError:
            # If not valid JSON, return the decoded string.
            return cast(Optional[T], text)

    def set(self, key: str, value: T) -> None:
        """
        Store a value in the Redis cache.

        Pydantic models are serialized to UTF-8 encoded JSON; every other value
        is handed to the redis client as-is. Redis, connection and encoding
        errors are swallowed, so a failed write is silent.

        Args:
            key: The key to store the value under
            value: The value to store
        """
        try:
            if not isinstance(value, BaseModel):
                # Backward compatibility for primitives
                self.cache.set(key, cast(Any, value))
                return
            # Serialize Pydantic models to JSON
            payload = value.model_dump_json().encode("utf-8")
            self.cache.set(key, payload)
        except (redis.RedisError, ConnectionError, UnicodeEncodeError):
            # The error is not re-raised, to maintain robustness
            pass

    def _to_config(self) -> RedisStoreConfig:
        """Build a config from the connection kwargs of the wrapped client's pool."""
        # Extract connection info from redis instance
        pool = self.cache.connection_pool
        kwargs: Dict[str, Any] = pool.connection_kwargs  # type: ignore[reportUnknownMemberType]

        raw_username = kwargs.get("username")
        raw_password = kwargs.get("password")
        raw_timeout = kwargs.get("socket_timeout")

        return RedisStoreConfig(
            host=str(kwargs.get("host", "localhost")),
            port=int(kwargs.get("port", 6379)),
            db=int(kwargs.get("db", 0)),
            username=None if raw_username is None else str(raw_username),
            password=None if raw_password is None else str(raw_password),
            ssl=bool(kwargs.get("ssl", False)),
            socket_timeout=None if raw_timeout is None else float(raw_timeout),
        )

    @classmethod
    def _from_config(cls, config: RedisStoreConfig) -> Self:
        """Create a store backed by a new ``redis.Redis`` client built from ``config``."""
        # Create new redis instance from config
        client = redis.Redis(
            host=config.host,
            port=config.port,
            db=config.db,
            username=config.username,
            password=config.password,
            ssl=config.ssl,
            socket_timeout=config.socket_timeout,
        )
        return cls(redis_instance=client)
|
||||
199
agent_dhal/agentdhal_extensions/code_executors/_common.py
Normal file
199
agent_dhal/agentdhal_extensions/code_executors/_common.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import inspect
|
||||
import re
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from textwrap import dedent, indent
|
||||
from typing import Any, Callable, Optional, Sequence, Set, TypeVar, Union
|
||||
|
||||
from agentdhal_core.code_executor import Alias, CodeResult, FunctionWithRequirements, FunctionWithRequirementsStr, Import
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
@dataclass
class CommandLineCodeResult(CodeResult):
    """Code result produced by a command line code executor."""

    # Code file associated with this result, if any.
    # NOTE(review): presumably the file the executed code was written to — confirm with the executors.
    code_file: Optional[str]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
    """Return the source text of ``func``.

    String-backed functions return their stored source. For everything else the
    source is read with ``inspect.getsource`` and, when it starts with a
    decorator, the first line is removed.
    """
    if isinstance(func, FunctionWithRequirementsStr):
        return func.func

    source = inspect.getsource(func)
    if not source.startswith("@"):
        return source

    # Strip the decorator
    first_newline = source.index("\n")
    return source[first_newline + 1 :]
|
||||
|
||||
|
||||
def _import_to_str(im: Import) -> str:
    """Render an import description as a Python import statement.

    A plain string becomes ``import X``, an alias becomes ``import X as Y`` and
    anything else is treated as a ``from M import a, b as c`` description.
    """
    if isinstance(im, str):
        return f"import {im}"
    if isinstance(im, Alias):
        return f"import {im.name} as {im.alias}"

    rendered_names = [
        entry if isinstance(entry, str) else f"{entry.name} as {entry.alias}" for entry in im.imports
    ]
    imports = ", ".join(rendered_names)
    return f"from {im.module} import {imports}"
|
||||
|
||||
|
||||
def build_python_functions_file(
    funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
    """:meta private:"""
    # Gather the global imports declared by every function that carries requirements.
    global_imports: Set[Import] = set()
    requirement_types = (FunctionWithRequirements, FunctionWithRequirementsStr)
    for candidate in funcs:
        if isinstance(candidate, requirement_types):
            global_imports.update(candidate.global_imports)

    # Import statements first, then each function's source followed by a blank line.
    import_section = "\n".join(_import_to_str(entry) for entry in global_imports) + "\n\n"
    function_section = "".join(_to_code(candidate) + "\n\n" for candidate in funcs)
    return import_section + function_section
|
||||
|
||||
|
||||
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
    """Generate a stub for a function as a string

    The stub consists of the ``def`` line with the function's signature, its
    docstring (when it has one) and an ellipsis body.

    Args:
        func (Callable[..., Any]): The function to generate a stub for

    Returns:
        str: The stub for the function
    """
    if isinstance(func, FunctionWithRequirementsStr):
        return to_stub(func.compiled_func)

    pieces = [f"def {func.__name__}{inspect.signature(func)}:\n"]

    doc = func.__doc__
    if doc:
        quoted = '"""' + dedent(doc) + '"""'
        pieces.append(indent(quoted, " ") + "\n")

    pieces.append(" ...")
    return "".join(pieces)
|
||||
|
||||
|
||||
# Raises ValueError if the file is not in the workspace
def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
    """Extract the target file name from a leading ``# filename:`` comment.

    Args:
        code (str): The code whose first line may be ``# filename: <name>``.
        workspace_path (Path): The workspace the file must live in.

    Returns:
        Optional[str]: The file path relative to the workspace, or None when the
        first line carries no (non-empty) filename comment.

    Raises:
        ValueError: If the named file resolves to a location outside the workspace.
    """
    first_line = code.split("\n")[0]
    # TODO - support other languages
    if first_line.startswith("# filename:"):
        # Split on the first colon only, so names that contain ":" stay intact.
        filename = first_line.split(":", 1)[1].strip()
        if not filename:
            # An empty name would resolve to the workspace directory itself.
            return None

        # Handle relative paths in the filename
        path = Path(filename)
        if not path.is_absolute():
            path = workspace_path / path
        path = path.resolve()
        # Throws an error if the file is not in the workspace
        relative = path.relative_to(workspace_path.resolve())
        return str(relative)

    return None
|
||||
|
||||
|
||||
def silence_pip(code: str, lang: str) -> str:
    """Apply -qqq flag to pip install commands."""
    if lang == "python":
        pattern = r"^! ?pip install"
    elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
        pattern = r"^pip install"
    else:
        return code

    def quieten(line: str) -> str:
        # Only lines that start with a pip install command and do not already
        # carry the flag are rewritten.
        found = re.search(pattern, line)
        if found is None or "-qqq" in line:
            return line
        command = found.group(0)
        return line.replace(command, command + " -qqq")

    return "\n".join(quieten(line) for line in code.split("\n"))
|
||||
|
||||
|
||||
def get_required_packages(code: str, lang: str) -> set[str]:
    """Collect the packages named in ``! pip install`` lines of Python code.

    Packages from every matching line are accumulated. Arguments may be
    separated by whitespace and/or commas; option flags (tokens starting with
    "-") are not reported as packages. A value that follows an option (e.g. the
    file after ``-r``) is not recognised as such and is still reported.

    Args:
        code (str): The code to scan.
        lang (str): The language of the code. Only "python" is scanned.

    Returns:
        set[str]: The package arguments found; empty for other languages.
    """
    ret: set[str] = set()
    if lang == "python":
        regex = r"^! ?pip install(.*)$"
    else:
        return ret

    # Find lines that start with pip install and collect their package arguments.
    for line in code.split("\n"):
        match = re.search(regex, line)
        if match is None:
            continue
        for req in re.split(r"[,\s]+", match.group(1)):
            if req and not req.startswith("-"):
                ret.add(req)
    return ret
|
||||
|
||||
|
||||
PYTHON_VARIANTS = ["python", "Python", "py"]


def lang_to_cmd(lang: str) -> str:
    """Map a code block language tag to the command used to run it.

    Raises:
        ValueError: If the language is unsupported, or a PowerShell language is
            requested and neither ``pwsh`` nor ``powershell`` is on the PATH.
    """
    if lang in PYTHON_VARIANTS:
        return "python"
    if lang.startswith("python") or lang in ["bash", "sh"]:
        return lang
    if lang in ["shell"]:
        return "sh"
    if lang not in ["pwsh", "powershell", "ps1"]:
        raise ValueError(f"Unsupported language: {lang}")

    # Check if pwsh is available, otherwise fall back to powershell
    for executable in ("pwsh", "powershell"):
        if shutil.which(executable) is not None:
            return executable
    raise ValueError("Powershell or pwsh is not installed. Please install one of them.")
|
||||
|
||||
|
||||
# Regular expression for finding a code block
|
||||
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*``` Matches multi-line code blocks.
|
||||
# The [ \t]* matches the potential spaces before language name.
|
||||
# The (\w+)? matches the language, where the ? indicates it is optional.
|
||||
# The [ \t]* matches the potential spaces (not newlines) after language name.
|
||||
# The \r?\n makes sure there is a linebreak after ```.
|
||||
# The (.*?) matches the code itself (non-greedy).
|
||||
# The \r?\n makes sure there is a linebreak before ```.
|
||||
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
|
||||
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
|
||||
|
||||
|
||||
def infer_lang(code: str) -> str:
    """infer the language for the code.
    TODO: make it robust.

    Returns "sh" for text that starts like a python/pip command line, "python"
    when the text compiles as Python, and "unknown" otherwise.
    """
    shell_prefixes = ("python ", "pip", "python3 ")
    if code.startswith(shell_prefixes):
        return "sh"

    # check if code is a valid python code
    try:
        compile(code, "test", "exec")
    except SyntaxError:
        # not a valid python code
        return "unknown"
    return "python"
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._azure_container_code_executor import ACADynamicSessionsCodeExecutor, TokenProvider
|
||||
|
||||
__all__ = ["TokenProvider", "ACADynamicSessionsCodeExecutor"]
|
||||
@@ -0,0 +1,522 @@
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Protocol, Sequence, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
|
||||
# async functions shouldn't use open()
|
||||
from anyio import open_file
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import (
|
||||
CodeBlock,
|
||||
CodeExecutor,
|
||||
CodeResult,
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .._common import build_python_functions_file, get_required_packages, to_stub
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
PYTHON_VARIANTS = ["python", "Python", "py"]
|
||||
|
||||
__all__ = ("ACADynamicSessionsCodeExecutor", "TokenProvider")
|
||||
|
||||
A = ParamSpec("A")
|
||||
|
||||
|
||||
class TokenProvider(Protocol):
    """Structural type for credential objects that can issue access tokens."""

    def get_token(
        self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
    ) -> AccessToken:
        """Return an access token for the requested scopes."""
        ...
|
||||
|
||||
|
||||
class ACADynamicSessionsCodeExecutor(CodeExecutor):
|
||||
"""(Experimental) A code executor class that executes code through an Azure
|
||||
Container Apps Dynamic Sessions instance.
|
||||
|
||||
.. note::
|
||||
|
||||
This class requires the :code:`azure` extra for the :code:`agentdhal-ext` package:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[azure]"
|
||||
|
||||
.. caution::
|
||||
|
||||
**This will execute LLM generated code on an Azure dynamic code container.**
|
||||
|
||||
The execution environment is similar to that of a jupyter notebook which allows for incremental code execution. The parameter functions are executed in order once at the beginning of each session. Each code block is then executed serially and in the order they are received. Each environment has a statically defined set of available packages which cannot be changed.
|
||||
Currently, attempting to use packages beyond what is available on the environment will result in an error. To get the list of supported packages, call the `get_available_packages` function.
|
||||
Currently the only supported language is Python.
|
||||
For Python code, use the language "python" for the code block.
|
||||
|
||||
Args:
|
||||
pool_management_endpoint (str): The azure container apps dynamic sessions endpoint.
|
||||
credential (TokenProvider): An object that implements the get_token function.
|
||||
timeout (int): The timeout for the execution of any single code block. Default is 60.
|
||||
work_dir (str): The working directory for the code execution. If None,
|
||||
a default working directory will be used. The default working
|
||||
directory is a temporal directory.
|
||||
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
|
||||
suppress_result_output (bool): By default the executor will attach any result info in the execution response to the result output. Set this to True to prevent this.
|
||||
session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart`
|
||||
|
||||
.. note::
|
||||
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
|
||||
"""
|
||||
|
||||
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
|
||||
"python",
|
||||
]
|
||||
FUNCTION_PROMPT_TEMPLATE: ClassVar[str] = """You have access to the following user defined functions.
|
||||
|
||||
$functions"""
|
||||
|
||||
_AZURE_API_VER = "2024-02-02-preview"
|
||||
|
||||
def __init__(
    self,
    pool_management_endpoint: str,
    credential: TokenProvider,
    timeout: int = 60,
    work_dir: Union[Path, str, None] = None,
    functions: Sequence[
        Union[
            FunctionWithRequirements[Any, A],
            Callable[..., Any],
            FunctionWithRequirementsStr,
        ]
    ] = [],
    functions_module: str = "functions",
    suppress_result_output: bool = False,
    session_id: Optional[str] = None,
):
    """Validate the arguments and set up the executor state.

    No request is sent here: the access token, the function setup and the remote
    working-directory setup all start out as not done.

    Raises:
        ValueError: If ``timeout`` is smaller than 1.
    """
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")

    self._work_dir: Optional[Path] = None
    self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None

    if work_dir is None:
        # No working directory given: use a fresh temporary directory.
        self._temp_dir = tempfile.TemporaryDirectory()
        Path(self._temp_dir.name).mkdir(exist_ok=True, parents=True)
    else:
        # Use the caller's directory, creating it if it doesn't exist.
        self._work_dir = Path(work_dir) if isinstance(work_dir, str) else work_dir
        self._work_dir.mkdir(exist_ok=True, parents=True)

    self._started = False

    self._functions_module = functions_module
    self._timeout = timeout
    self._functions = functions
    self._func_code: Optional[str] = None

    # Setup could take some time so we intentionally wait for the first code block to do it.
    # With no functions there is nothing to set up, so it counts as complete already.
    self._setup_functions_complete = len(functions) == 0

    self._suppress_result_output = suppress_result_output

    self._pool_management_endpoint = pool_management_endpoint
    self._credential: TokenProvider = credential
    self._access_token: str | None = None
    self._session_id: str = session_id or str(uuid4())
    self._available_packages: set[str] | None = None
    # cwd needs to be set to /mnt/data to properly read uploaded files and download written files
    self._setup_cwd_complete = False
|
||||
|
||||
def _ensure_access_token(self) -> None:
    """Make sure ``self._access_token`` holds a usable bearer token.

    A token is requested from the credential when none is cached yet, or when
    the cached token's ``expires_on`` timestamp is less than a minute away.
    Tokens with no known expiry are kept until replaced.
    """
    import time

    # Expiry of the cached token in seconds since the epoch (None when unknown).
    expires_on = getattr(self, "_access_token_expires_on", None)
    # Refresh a little early so a request never goes out with a token that is
    # about to lapse.
    expired = expires_on is not None and time.time() >= expires_on - 60

    if not self._access_token or expired:
        scope = "https://dynamicsessions.io/.default"
        access_token = self._credential.get_token(scope)
        self._access_token = access_token.token
        self._access_token_expires_on = getattr(access_token, "expires_on", None)
|
||||
|
||||
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
    """(Experimental) Format the functions for a prompt.

    The template includes one variable:
    - `$functions`: The functions formatted as stubs with two newlines between each function.

    Args:
        prompt_template (str): The prompt template. Default is the class default.

    Returns:
        str: The formatted prompt.
    """
    stubs = [to_stub(func) for func in self._functions]
    return Template(prompt_template).substitute(functions="\n\n".join(stubs))
|
||||
|
||||
@property
def functions_module(self) -> str:
    """(Experimental) The module name for the functions.

    The functions file written to the working directory is named after it.
    """
    return self._functions_module
|
||||
|
||||
@property
def functions(self) -> List[str]:
    """Not supported by this executor; accessing it always raises NotImplementedError."""
    raise NotImplementedError
|
||||
|
||||
@property
def timeout(self) -> int:
    """(Experimental) The timeout for code execution.

    This is the value given to the constructor.
    """
    return self._timeout
|
||||
|
||||
@property
def work_dir(self) -> Path:
    """The local working directory of the executor.

    Returns the user-supplied directory when there is one (warning when it is
    the current directory), otherwise the path of the temporary directory.

    Raises:
        RuntimeError: If neither a working nor a temporary directory is set.
    """
    if self._work_dir is None:
        # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
        if self._temp_dir is None:
            raise RuntimeError("Working directory not properly initialized")
        return Path(self._temp_dir.name)

    # If a user specifies the current directory, warn them that this is deprecated
    if self._work_dir == Path("."):
        warnings.warn(
            "Using the current directory as work_dir is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
    return self._work_dir
|
||||
|
||||
def _construct_url(self, path: str) -> str:
    """Build the request URL for ``path`` on the pool management endpoint.

    The API version and the session identifier are appended as query
    parameters.

    Args:
        path (str): The path below the endpoint, without a leading slash.

    Returns:
        str: The full URL including the query string.
    """
    from urllib.parse import urlencode

    endpoint = self._pool_management_endpoint
    if not endpoint.endswith("/"):
        endpoint += "/"
    # Percent-encode the query values: a caller-supplied session id containing
    # characters such as "&", "#" or "%" would otherwise change the meaning of
    # the URL. Plain identifiers (e.g. the default UUID) come out unchanged.
    query = urlencode({"api-version": self._AZURE_API_VER, "identifier": self._session_id})
    return f"{endpoint}{path}?{query}"
|
||||
|
||||
async def get_available_packages(self, cancellation_token: CancellationToken) -> set[str]:
    """Return the names of the packages available in the execution environment.

    The cached set is returned when present. Otherwise a snippet listing
    ``pkg_resources.working_set`` is executed and its printed list is parsed.

    Raises:
        ValueError: If the listing snippet exits with a non-zero code.
    """
    if self._available_packages is not None:
        return self._available_packages

    listing_code = "\nimport pkg_resources\n[d.project_name for d in pkg_resources.working_set]\n"
    listing_block = CodeBlock(code=listing_code, language="python")
    ret = await self._execute_code_dont_check_setup([listing_block], cancellation_token)
    if ret.exit_code != 0:
        raise ValueError(f"Failed to get list of available packages: {ret.output.strip()}")

    # The output is a printed list with one entry per line: drop the brackets,
    # split the entries and strip their quotes.
    entries = ret.output.strip("[]").split(",\n")
    return {entry.strip(" '") for entry in entries}
|
||||
|
||||
async def _populate_available_packages(self, cancellation_token: CancellationToken) -> None:
    """Fetch the available packages and cache them on the instance."""
    packages = await self.get_available_packages(cancellation_token)
    self._available_packages = packages
|
||||
|
||||
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
    """Build, validate and load the user-defined functions.

    On the first call the functions file content is generated and the packages
    the functions require are checked against the environment's packages. The
    content is then written to ``<functions_module>.py`` in the working
    directory and executed to surface syntax or import errors.

    Raises:
        ValueError: If a required package is unavailable or the functions fail to load.
    """
    if not self._func_code:
        self._func_code = build_python_functions_file(self._functions)

        # Check required function imports and packages
        # Should we also be checking the imports?
        required_packages = {
            package
            for func in self._functions
            if isinstance(func, FunctionWithRequirements)
            for package in func.python_packages
        }

        if self._available_packages is None:
            await self._populate_available_packages(cancellation_token)

        if self._available_packages is not None:
            missing_pkgs = required_packages - self._available_packages
            if missing_pkgs:
                raise ValueError(f"Packages unavailable in environment: {missing_pkgs}")

    func_file = self.work_dir / f"{self._functions_module}.py"
    func_file.write_text(self._func_code)

    # Attempt to load the function file to check for syntax errors, imports etc.
    load_block = CodeBlock(code=self._func_code, language="python")
    exec_result = await self._execute_code_dont_check_setup([load_block], cancellation_token)
    if exec_result.exit_code != 0:
        raise ValueError(f"Functions failed to load: {exec_result.output.strip()}")

    self._setup_functions_complete = True
|
||||
|
||||
async def _setup_cwd(self, cancellation_token: CancellationToken) -> None:
    """Switch the remote session's working directory to /mnt/data.

    Raises:
        ValueError: If the directory change exits with a non-zero code.
    """
    # Change the cwd to /mnt/data to properly have access to uploaded files
    chdir_block = CodeBlock(code="import os; os.chdir('/mnt/data')", language="python")
    exec_result = await self._execute_code_dont_check_setup([chdir_block], cancellation_token)
    if exec_result.exit_code != 0:
        raise ValueError("Failed to set up Azure container working directory")

    self._setup_cwd_complete = True
|
||||
|
||||
async def get_file_list(self, cancellation_token: CancellationToken) -> List[str]:
    """Return the file names reported by the session's ``files`` endpoint.

    Args:
        cancellation_token (CancellationToken): token linked to the HTTP request.

    Returns:
        List[str]: the ``filename`` of every entry in the response's ``value`` list.

    Raises:
        asyncio.TimeoutError: if the request times out.
        asyncio.CancelledError: if the request is cancelled.
        ConnectionError: if the endpoint answers with an error status.
    """
    self._ensure_access_token()
    session_timeout = aiohttp.ClientTimeout(total=float(self._timeout))
    auth_headers = {"Authorization": f"Bearer {self._access_token}"}
    files_url = self._construct_url("files")
    async with aiohttp.ClientSession(timeout=session_timeout) as client:
        request = asyncio.create_task(client.get(files_url, headers=auth_headers))
        cancellation_token.link_future(request)
        try:
            resp = await request
            resp.raise_for_status()
            payload = await resp.json()
        except asyncio.TimeoutError as e:
            # e.add_note is only in py 3.11+
            raise asyncio.TimeoutError("Timeout getting file list") from e
        except asyncio.CancelledError as e:
            # e.add_note is only in py 3.11+
            raise asyncio.CancelledError("File list retrieval cancelled") from e
        except aiohttp.ClientResponseError as e:
            raise ConnectionError("Error while getting file list") from e

    # Each entry carries its metadata under "properties".
    return [entry["properties"]["filename"] for entry in payload["value"]]
|
||||
|
||||
async def upload_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> None:
    """Upload files from the local working directory to the session.

    Args:
        files (List[Union[Path, str]]): paths, resolved against ``work_dir``, to upload.
        cancellation_token (CancellationToken): token linked to every upload request.

    Raises:
        FileNotFoundError: if an entry is not an existing file under ``work_dir``.
        asyncio.TimeoutError: if an upload times out.
        asyncio.CancelledError: if an upload is cancelled.
        ConnectionError: if the endpoint answers with an error status.
    """
    self._ensure_access_token()
    # TODO: Better to use the client auth system rather than headers
    auth_headers = {"Authorization": f"Bearer {self._access_token}"}
    upload_url = self._construct_url("files/upload")
    session_timeout = aiohttp.ClientTimeout(total=float(self._timeout))
    async with aiohttp.ClientSession(timeout=session_timeout) as client:
        for file in files:
            local_file = self.work_dir / file
            if not local_file.is_file():
                # TODO: what to do here?
                raise FileNotFoundError(f"{file} does not exist")

            form = aiohttp.FormData()
            async with await open_file(local_file, "rb") as handle:
                form.add_field(
                    "file",
                    handle,
                    filename=os.path.basename(local_file),
                    content_type="application/octet-stream",
                )

                # The request is issued while the file handle is still open.
                request = asyncio.create_task(client.post(upload_url, headers=auth_headers, data=form))
                cancellation_token.link_future(request)
                try:
                    resp = await request
                    resp.raise_for_status()
                except asyncio.TimeoutError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.TimeoutError("Timeout uploading files") from e
                except asyncio.CancelledError as e:
                    # e.add_note is only in py 3.11+
                    raise asyncio.CancelledError("Uploading files cancelled") from e
                except aiohttp.ClientResponseError as e:
                    raise ConnectionError("Error while uploading files") from e
|
||||
|
||||
async def download_files(self, files: List[Union[Path, str]], cancellation_token: CancellationToken) -> List[str]:
    """Download files from the session into the local working directory.

    Args:
        files (List[Union[Path, str]]): names of the remote files to fetch.
        cancellation_token (CancellationToken): token linked to every download request.

    Returns:
        List[str]: the local path (under ``work_dir``) of each requested file, in order.

    Raises:
        FileNotFoundError: if a requested name is not in the session's file list.
        asyncio.TimeoutError: if a download times out.
        asyncio.CancelledError: if a download is cancelled.
        ConnectionError: if the endpoint answers with an error status.
    """
    self._ensure_access_token()
    # get_file_list returns plain strings; a set gives O(1) membership checks.
    available_files = set(await self.get_file_list(cancellation_token))
    # TODO: Better to use the client auth system rather than headers
    headers = {"Authorization": f"Bearer {self._access_token}"}
    timeout = aiohttp.ClientTimeout(total=float(self._timeout))
    local_paths: List[str] = []
    async with aiohttp.ClientSession(timeout=timeout) as client:
        for file in files:
            # A Path never compares equal to a str, so Path inputs must be turned
            # into their string name before the membership test; otherwise every
            # Path argument is reported as missing.
            remote_name = file if isinstance(file, str) else Path(file).as_posix()
            if remote_name not in available_files:
                # TODO: what's the right thing to do here?
                raise FileNotFoundError(f"{file} does not exist")

            url = self._construct_url(f"files/content/{remote_name}")

            task = asyncio.create_task(
                client.get(
                    url,
                    headers=headers,
                )
            )
            cancellation_token.link_future(task)
            try:
                resp = await task
                resp.raise_for_status()
                local_path = self.work_dir / file
                local_paths.append(str(local_path))
                async with await open_file(local_path, "wb") as f:
                    await f.write(await resp.read())
            except asyncio.TimeoutError as e:
                # e.add_note is only in py 3.11+
                raise asyncio.TimeoutError("Timeout downloading files") from e
            except asyncio.CancelledError as e:
                # e.add_note is only in py 3.11+
                raise asyncio.CancelledError("Downloading files cancelled") from e
            except aiohttp.ClientResponseError as e:
                raise ConnectionError("Error while downloading files") from e
    return local_paths
|
||||
|
||||
async def execute_code_blocks(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CodeResult:
    """(Experimental) Execute the code blocks and return the result.

    Before executing, any pending one-time setup is run: the available package
    list is populated, the user functions are set up, and the session working
    directory is set.

    Args:
        code_blocks (List[CodeBlock]): The code blocks to execute.
        cancellation_token (CancellationToken): a token to cancel the operation

    Returns:
        CodeResult: The result of the code execution."""
    self._ensure_access_token()

    # Each setup step is skipped once its state shows it has already run.
    packages_unknown = self._available_packages is None
    if packages_unknown:
        await self._populate_available_packages(cancellation_token)

    if not self._setup_functions_complete:
        await self._setup_functions(cancellation_token)

    if not self._setup_cwd_complete:
        await self._setup_cwd(cancellation_token)

    result = await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
    return result
|
||||
|
||||
# The http call here should be replaced by an actual Azure client call once its available
async def _execute_code_dont_check_setup(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CodeResult:
    """Send each code block to the ``code/execute`` endpoint without running setup.

    Blocks are posted one at a time. Processing stops early when a block's
    language is unsupported or it needs packages that are not available.

    Args:
        code_blocks (List[CodeBlock]): the code blocks to execute, in order.
        cancellation_token (CancellationToken): token linked to every request.

    Returns:
        CodeResult: exit code (1 on an unsupported language, missing packages or a
        reported failure status, otherwise 0) and the accumulated output.

    Raises:
        asyncio.TimeoutError: if a request times out.
        asyncio.CancelledError: if a request is cancelled.
        ConnectionError: if the endpoint answers with an error status.
    """
    logs_all = ""
    exitcode = 0

    # TODO: Better to use the client auth system rather than headers
    assert self._access_token is not None
    request_headers = {
        "Authorization": f"Bearer {self._access_token}",
        "Content-Type": "application/json",
    }
    execute_url = self._construct_url("code/execute")
    session_timeout = aiohttp.ClientTimeout(total=float(self._timeout))
    async with aiohttp.ClientSession(timeout=session_timeout) as client:
        for code_block in code_blocks:
            code = code_block.code
            lang = code_block.language.lower()
            if lang in PYTHON_VARIANTS:
                lang = "python"

            if lang not in self.SUPPORTED_LANGUAGES:
                # In case the language is not supported, we return an error message.
                exitcode = 1
                logs_all += f"\nunknown language {lang}"
                break

            if self._available_packages is not None:
                req_pkgs = get_required_packages(code, lang)
                missing_pkgs = set(req_pkgs - self._available_packages)
                if missing_pkgs:
                    # In case the code requires packages that are not available in the environment
                    exitcode = 1
                    logs_all += f"\nPython packages unavailable in environment: {missing_pkgs}"
                    break

            payload = {
                "properties": {
                    "codeInputType": "inline",
                    "executionType": "synchronous",
                    "code": code,
                }
            }
            request = asyncio.create_task(client.post(execute_url, headers=request_headers, json=payload))
            cancellation_token.link_future(request)
            try:
                response = await request
                response.raise_for_status()
                body = await response.json()
                props = body["properties"]
                logs_all += props.get("stderr", "") + props.get("stdout", "")
                status = props["status"]
                if "Success" in status:
                    if not self._suppress_result_output:
                        logs_all += str(props["result"])
                elif "Failure" in status:
                    exitcode = 1
            except asyncio.TimeoutError as e:
                logs_all += "\n Timeout"
                # e.add_note is only in py 3.11+
                raise asyncio.TimeoutError(logs_all) from e
            except asyncio.CancelledError as e:
                logs_all += "\n Cancelled"
                # e.add_note is only in py 3.11+
                raise asyncio.CancelledError(logs_all) from e
            except aiohttp.ClientResponseError as e:
                logs_all += "\nError while sending code block to endpoint"
                raise ConnectionError(logs_all) from e

    return CodeResult(exit_code=exitcode, output=logs_all)
|
||||
|
||||
async def restart(self) -> None:
    """(Experimental) Restart the code executor.

    Assigns a fresh session ID and clears the cached access token, the cached
    package list and both setup flags, so the next execution performs the
    setup steps again.
    """
    # A new identifier is generated for the next session.
    self._session_id = str(uuid4())
    # Forget cached credentials and environment information.
    self._access_token = None
    self._available_packages = None
    # Mark every lazy setup step as pending again.
    self._setup_functions_complete = False
    self._setup_cwd_complete = False
|
||||
|
||||
async def start(self) -> None:
    """(Experimental) Start the code executor.

    Only records that the executor has been started by setting ``_started``."""
    # No setup needed for this executor
    self._started = True
|
||||
|
||||
async def stop(self) -> None:
    """(Experimental) Stop the code executor.

    Cleans up the temporary working directory when one is held, then marks the
    executor as not started."""
    temp_dir = self._temp_dir
    if temp_dir is not None:
        temp_dir.cleanup()
        self._temp_dir = None
    self._started = False
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._docker_code_executor import DockerCommandLineCodeExecutor
|
||||
|
||||
__all__ = ["DockerCommandLineCodeExecutor"]
|
||||
@@ -0,0 +1,613 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/docker_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import shlex
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Tuple, Union
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.code_executor import (
|
||||
CodeBlock,
|
||||
CodeExecutor,
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from docker.types import DeviceRequest
|
||||
|
||||
from .._common import (
|
||||
CommandLineCodeResult,
|
||||
build_python_functions_file,
|
||||
get_file_name_from_content,
|
||||
lang_to_cmd,
|
||||
silence_pip,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
else:
|
||||
from typing_extensions import Self
|
||||
|
||||
try:
|
||||
import asyncio_atexit
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Missing dependecies for DockerCommandLineCodeExecutor. Please ensure the autogen-ext package was installed with the 'docker' extra."
|
||||
) from e
|
||||
|
||||
|
||||
async def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
    """Poll ``container`` until its status is ``"running"`` or the wait budget is spent.

    Between polls the coroutine sleeps ``stop_time`` seconds and refreshes the
    container with ``container.reload`` in a worker thread. Only the sleep
    intervals are counted against ``timeout``.

    Args:
        container: object exposing ``status`` and ``reload()``.
        timeout: maximum accumulated sleep time, in seconds.
        stop_time: sleep duration between polls, in seconds.

    Raises:
        ValueError: if the container is still not running once polling ends.
    """
    waited = 0.0
    while True:
        if container.status == "running" or waited >= timeout:
            break
        await asyncio.sleep(stop_time)
        waited += stop_time
        await asyncio.to_thread(container.reload)

    if container.status != "running":
        raise ValueError("Container failed to start")
|
||||
|
||||
|
||||
A = ParamSpec("A")
|
||||
|
||||
|
||||
class DockerCommandLineCodeExecutorConfig(BaseModel):
    """Configuration for DockerCommandLineCodeExecutor"""

    # Docker image to use for code execution.
    image: str = "python:3-slim"
    # Name of the Docker container; the executor autogenerates one when None.
    container_name: Optional[str] = None
    # Timeout for code execution.
    timeout: int = 60
    # Working directory for code execution; the executor uses a temporary directory when None.
    work_dir: Optional[str] = None
    # Directory bound into the container; the executor falls back to work_dir when None.
    bind_dir: Optional[str] = None
    # Whether the container is removed automatically when it is stopped.
    auto_remove: bool = True
    # Whether the executor stops the container on stop / exit.
    stop_container: bool = True
    # Name of the module created to hold the user-defined functions.
    functions_module: str = "functions"
    # Extra volumes to mount: host source path -> {"bind": container path, "mode": ...}.
    extra_volumes: Dict[str, Dict[str, str]] = {}
    # Extra host mappings to add to the container.
    extra_hosts: Dict[str, str] = {}
    # Optional shell command run when the container shell starts.
    init_command: Optional[str] = None
    # Whether temporary code files are deleted after execution.
    delete_tmp_files: bool = False
|
||||
|
||||
|
||||
class DockerCommandLineCodeExecutor(CodeExecutor, Component[DockerCommandLineCodeExecutorConfig]):
|
||||
"""Executes code through a command line environment in a Docker container.
|
||||
|
||||
.. note::
|
||||
|
||||
This class requires the :code:`docker` extra for the :code:`autogen-ext` package:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[docker]"
|
||||
|
||||
|
||||
The executor first saves each code block in a file in the working
|
||||
directory, and then executes the code file in the container.
|
||||
The executor executes the code blocks in the order they are received.
|
||||
Currently, the executor only supports Python and shell scripts.
|
||||
For Python code, use the language "python" for the code block.
|
||||
For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code block.
|
||||
|
||||
Args:
|
||||
image (_type_, optional): Docker image to use for code execution.
|
||||
Defaults to "python:3-slim".
|
||||
container_name (Optional[str], optional): Name of the Docker container
|
||||
which is created. If None, will autogenerate a name. Defaults to None.
|
||||
timeout (int, optional): The timeout for code execution. Defaults to 60.
|
||||
work_dir (Union[Path, str], optional): The working directory for the code
|
||||
execution. Defaults to temporary directory.
|
||||
bind_dir (Union[Path, str], optional): The directory that will be bound
|
||||
to the code executor container. Useful for cases where you want to spawn
|
||||
the container from within a container. Defaults to work_dir.
|
||||
auto_remove (bool, optional): If true, will automatically remove the Docker
|
||||
container when it is stopped. Defaults to True.
|
||||
stop_container (bool, optional): If true, will automatically stop the
|
||||
container when stop is called, when the context manager exits or when
|
||||
the Python process exits with atexit. Defaults to True.
|
||||
device_requests (Optional[List[DeviceRequest]], optional): A list of device request instances to add to the container for exposing GPUs (e.g., [docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])]). Defaults to None.
|
||||
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
|
||||
functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions".
|
||||
extra_volumes (Optional[Dict[str, Dict[str, str]]], optional): A dictionary of extra volumes (beyond the work_dir) to mount to the container;
|
||||
key is host source path and value 'bind' is the container path. Defaults to None.
|
||||
Example: extra_volumes = {'/home/user1/': {'bind': '/mnt/vol2', 'mode': 'rw'}, '/var/www': {'bind': '/mnt/vol1', 'mode': 'ro'}}
|
||||
extra_hosts (Optional[Dict[str, str]], optional): A dictionary of host mappings to add to the container. (See Docker docs on extra_hosts) Defaults to None.
|
||||
Example: extra_hosts = {"kubernetes.docker.internal": "host-gateway"}
|
||||
init_command (Optional[str], optional): A shell command to run before each shell operation execution. Defaults to None.
|
||||
Example: init_command="kubectl config use-context docker-hub"
|
||||
delete_tmp_files (bool, optional): If true, will delete temporary files after execution. Defaults to False.
|
||||
|
||||
.. note::
|
||||
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = DockerCommandLineCodeExecutorConfig
|
||||
component_provider_override = "agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor"
|
||||
|
||||
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
|
||||
"bash",
|
||||
"shell",
|
||||
"sh",
|
||||
"pwsh",
|
||||
"powershell",
|
||||
"ps1",
|
||||
"python",
|
||||
]
|
||||
|
||||
FUNCTION_PROMPT_TEMPLATE: ClassVar[
|
||||
str
|
||||
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
|
||||
|
||||
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
|
||||
|
||||
$functions"""
|
||||
|
||||
def __init__(
    self,
    image: str = "python:3-slim",
    container_name: Optional[str] = None,
    *,
    timeout: int = 60,
    work_dir: Union[Path, str, None] = None,
    bind_dir: Optional[Union[Path, str]] = None,
    auto_remove: bool = True,
    stop_container: bool = True,
    device_requests: Optional[List[DeviceRequest]] = None,
    functions: Sequence[
        Union[
            FunctionWithRequirements[Any, A],
            Callable[..., Any],
            FunctionWithRequirementsStr,
        ]
    ] = [],
    functions_module: str = "functions",
    extra_volumes: Optional[Dict[str, Dict[str, str]]] = None,
    extra_hosts: Optional[Dict[str, str]] = None,
    init_command: Optional[str] = None,
    delete_tmp_files: bool = False,
):
    """Validate the arguments and record the executor configuration.

    No Docker call is made here. A given ``work_dir`` is created on disk if it
    does not exist yet, and a ``DeprecationWarning`` is emitted when it resolves
    to the current directory.

    Raises:
        ValueError: if ``timeout`` is below 1 or ``functions_module`` is not a
            valid Python identifier.
    """
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")

    # Handle working directory logic
    if work_dir is not None:
        chosen_dir = Path(work_dir) if isinstance(work_dir, str) else work_dir
        # Emit a deprecation warning if the user is using the current directory as working directory
        if chosen_dir.resolve() == Path.cwd().resolve():
            warnings.warn(
                "Using the current directory as work_dir is deprecated.",
                DeprecationWarning,
                stacklevel=2,
            )
        self._work_dir = chosen_dir
        # Create the working directory if it doesn't exist
        self._work_dir.mkdir(exist_ok=True, parents=True)
    else:
        self._work_dir = None

    self.container_name = (
        container_name if container_name is not None else f"agentdhal-code-exec-{uuid.uuid4()}"
    )
    self._timeout = timeout

    # Handle bind_dir
    self._bind_dir: Optional[Path] = None
    if bind_dir is None:
        # Default to work_dir if not provided
        self._bind_dir = self._work_dir
    elif isinstance(bind_dir, str):
        self._bind_dir = Path(bind_dir)
    else:
        self._bind_dir = bind_dir

    # Track temporary directory
    self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
    self._temp_dir_path: Optional[Path] = None

    self._started = False

    self._auto_remove = auto_remove
    self._stop_container = stop_container
    self._image = image

    if not functions_module.isidentifier():
        raise ValueError("Module name must be a valid Python identifier")

    self._functions_module = functions_module
    self._functions = functions
    self._extra_volumes = {} if extra_volumes is None else extra_volumes
    self._extra_hosts = {} if extra_hosts is None else extra_hosts
    self._init_command = init_command
    self._delete_tmp_files = delete_tmp_files
    self._device_requests = device_requests

    # Setup could take some time so we intentionally wait for the first code block to do it.
    # With no functions there is nothing to set up, so the step counts as done.
    self._setup_functions_complete = len(functions) == 0

    self._container: Container | None = None
    self._running = False

    self._loop: Optional[asyncio.AbstractEventLoop] = None
    self._cancellation_futures: List[ConcurrentFuture[None]] = []
|
||||
|
||||
@property
def timeout(self) -> int:
    """(Experimental) The code execution timeout stored on this executor."""
    configured_timeout = self._timeout
    return configured_timeout
|
||||
|
||||
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
    """Write the user functions module, install its requirements and verify it loads.

    The generated module is written to ``<functions_module>.py`` in the working
    directory. Packages declared by ``FunctionWithRequirements`` entries are
    installed with pip inside the container, and the module source is then
    executed once to catch syntax or import errors.

    Args:
        cancellation_token (CancellationToken): token passed on to the executions.

    Raises:
        ValueError: if the pip install or the module load returns a non-zero exit code.
    """
    func_file_content = build_python_functions_file(self._functions)
    func_file = self.work_dir / f"{self._functions_module}.py"
    # Encode explicitly: the code files run in the container are written as UTF-8,
    # while the platform default encoding may differ.
    func_file.write_text(func_file_content, encoding="utf-8")

    # Collect requirements; sorting keeps the generated pip command deterministic.
    required_packages = sorted(
        {
            package
            for func in self._functions
            if isinstance(func, FunctionWithRequirements)
            for package in func.python_packages
        }
    )
    if required_packages:
        logging.info("Ensuring packages are installed in executor.")

        packages = shlex.join(required_packages)

        result = await self._execute_code_dont_check_setup(
            [CodeBlock(code=f"python -m pip install {packages}", language="sh")], cancellation_token
        )

        if result.exit_code != 0:
            # The result carries a single combined output stream, so report it once.
            raise ValueError(f"Pip install failed. {result.output}")

    # Attempt to load the function file to check for syntax errors, imports etc.
    exec_result = await self._execute_code_dont_check_setup(
        [CodeBlock(code=func_file_content, language="python")], cancellation_token
    )

    if exec_result.exit_code != 0:
        raise ValueError(f"Functions failed to load: {exec_result.output}")

    self._setup_functions_complete = True
|
||||
|
||||
async def _kill_running_command(self, command: List[str]) -> None:
    """Run ``pkill -f`` in the container for the space-joined ``command``.

    Does nothing when there is no container or the executor is not running.
    """
    container = self._container
    if container is None or not self._running:
        return
    pattern = " ".join(command)
    await asyncio.to_thread(container.exec_run, ["pkill", "-f", pattern])
|
||||
|
||||
async def _execute_command(self, command: List[str], cancellation_token: CancellationToken) -> Tuple[str, int]:
    """Run ``command`` in the container and return its decoded output and exit code.

    The blocking ``exec_run`` call runs in a worker thread and is linked to
    ``cancellation_token``. On cancellation a kill of the command is scheduled
    on the executor's event loop and a fixed message with exit code 1 is returned.

    Args:
        command (List[str]): the command and its arguments.
        cancellation_token (CancellationToken): token linked to the execution task.

    Returns:
        Tuple[str, int]: UTF-8 decoded output (with a timeout note appended for
        exit code 124) and the exit code.

    Raises:
        ValueError: if the container is not running.
    """
    if self._container is None or not self._running:
        raise ValueError("Container is not running. Must first be started with either start or a context manager.")

    exec_task = asyncio.create_task(asyncio.to_thread(self._container.exec_run, command))
    cancellation_token.link_future(exec_task)

    # Wait for the exec task to finish.
    try:
        result = await exec_task
    except asyncio.CancelledError:
        # Schedule a task to kill the running command in the background.
        if not self._loop or self._loop.is_closed():
            logging.warning(
                f"Cannot schedule kill command: Executor loop is not available or closed (loop: {self._loop!r})."
            )
        else:
            try:
                logging.debug(f"Scheduling kill command via run_coroutine_threadsafe on loop {self._loop!r}")
                future: ConcurrentFuture[None] = asyncio.run_coroutine_threadsafe(
                    self._kill_running_command(command), self._loop
                )
                self._cancellation_futures.append(future)
                logging.debug(f"Kill command scheduled, future: {future!r}")
            except RuntimeError as e:
                logging.error(f"Failed to schedule kill command on loop {self._loop!r}: {e}")
            except Exception as e:
                logging.exception(f"Unexpected error scheduling kill command: {e}")
        return "Code execution was cancelled.", 1

    exit_code = result.exit_code
    output = result.output.decode("utf-8")
    # 124 is the exit code this method reports as a timeout.
    if exit_code == 124:
        output += "\n Timeout"
    return output, exit_code
|
||||
|
||||
async def _execute_code_dont_check_setup(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
    """Write each code block to a file in the working directory and run it in the container.

    Blocks run in order, each under ``timeout``; execution stops at the first
    block with a non-zero exit code or with a file name outside the workspace.
    The files that were written are removed afterwards when ``delete_tmp_files`` is set.

    Args:
        code_blocks (List[CodeBlock]): the code blocks to execute.
        cancellation_token (CancellationToken): token passed on to each command.

    Returns:
        CommandLineCodeResult: the last exit code, the concatenated outputs and
        the path of the first code file written (or None).

    Raises:
        ValueError: if the container is not running or ``code_blocks`` is empty.
    """
    if self._container is None or not self._running:
        raise ValueError("Container is not running. Must first be started with either start or a context manager.")

    if len(code_blocks) == 0:
        raise ValueError("No code blocks to execute.")

    collected_output: List[str] = []
    written_files: List[Path] = []
    final_exit_code = 0
    try:
        for block in code_blocks:
            lang = block.language.lower()
            code = silence_pip(block.code, lang)

            # Check if there is a filename comment
            try:
                filename = get_file_name_from_content(code, self.work_dir)
            except ValueError:
                collected_output.append("Filename is not in the workspace")
                final_exit_code = 1
                break

            if not filename:
                # No name given: derive one from the hash of the code.
                filename = f"tmp_code_{sha256(code.encode()).hexdigest()}.{lang}"

            code_path = self.work_dir / filename
            code_path.write_text(code, encoding="utf-8")
            written_files.append(code_path)

            command = ["timeout", str(self._timeout), lang_to_cmd(lang), filename]

            output, final_exit_code = await self._execute_command(command, cancellation_token)
            collected_output.append(output)
            if final_exit_code != 0:
                break
    finally:
        if self._delete_tmp_files:
            for written in written_files:
                try:
                    written.unlink()
                except OSError:
                    # Best effort cleanup; FileNotFoundError is an OSError too.
                    pass

    first_file = str(written_files[0]) if written_files else None
    return CommandLineCodeResult(
        exit_code=final_exit_code, output="".join(collected_output), code_file=first_file
    )
|
||||
|
||||
@property
def work_dir(self) -> Path:
    """The directory code files are written to.

    Returns the user-specified working directory when there is one (warning if
    it is ``Path(".")``), otherwise the path of the temporary directory.

    Raises:
        RuntimeError: if neither a working directory nor a temporary directory is set.
    """
    configured = self._work_dir
    # If a user specifies a working directory, use that
    if configured is not None:
        # If a user specifies the current directory, warn them that this is deprecated
        if configured == Path("."):
            warnings.warn(
                "Using the current directory as work_dir is deprecated.",
                DeprecationWarning,
                stacklevel=2,
            )
        return configured

    # If a user does not specify a working directory, use the default directory (tempfile.TemporaryDirectory)
    if self._temp_dir is not None:
        return Path(self._temp_dir.name)

    raise RuntimeError("Working directory not properly initialized")
|
||||
|
||||
@property
def bind_dir(self) -> Path:
    """The directory bound into the container.

    Returns the stored bind directory when one is set, otherwise ``work_dir``.
    """
    # Otherwise bind_dir is set to the current work_dir as default
    if self._bind_dir is None:
        return self.work_dir
    # If the user specified a bind directory, return it
    return self._bind_dir
|
||||
|
||||
async def execute_code_blocks(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
    """(Experimental) Execute the code blocks and return the result.

    The user functions are set up first if that has not completed yet.

    Args:
        code_blocks (List[CodeBlock]): The code blocks to execute.
        cancellation_token (CancellationToken): a token to cancel the operation

    Returns:
        CommandlineCodeResult: The result of the code execution."""
    functions_pending = not self._setup_functions_complete
    if functions_pending:
        await self._setup_functions(cancellation_token)

    result = await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
    return result
|
||||
|
||||
async def restart(self) -> None:
    """(Experimental) Restart the Docker container code executor.

    Raises:
        ValueError: if the container is not running when called, or if it is not
            in the ``"running"`` state after the restart (the container logs are
            included in the message).
    """
    if self._container is None or not self._running:
        raise ValueError("Container is not running. Must first be started with either start or a context manager.")

    await asyncio.to_thread(self._container.restart)  # type: ignore
    # The status attribute is served from cached container attributes; refresh
    # them so the check below sees the state after the restart, not before it.
    await asyncio.to_thread(self._container.reload)
    if self._container.status != "running":
        self._running = False
        # Fetching logs is a blocking call, so keep it off the event loop.
        raw_logs = await asyncio.to_thread(self._container.logs)
        logs_str = raw_logs.decode("utf-8")
        raise ValueError(f"Failed to restart container. Logs: {logs_str}")
|
||||
|
||||
async def stop(self) -> None:
    """(Experimental) Stop the code executor.

    Cleans up the temporary directory (if one was created), waits for any
    pending cancellation futures, and stops the Docker container. Docker and
    other errors during the stop are logged rather than raised. The executor is
    always marked as not running at the end.
    If the executor is not running, the method does nothing.
    """
    if not self._running:
        return

    if self._temp_dir is not None:
        self._temp_dir.cleanup()
        self._temp_dir = None

    client = docker.from_env()
    try:
        try:
            container = await asyncio.to_thread(client.containers.get, self.container_name)
        except NotFound:
            logging.debug(f"Container {self.container_name} not found during stop...")
            self._running = False
            self._cancellation_futures.clear()
            return

        if self._cancellation_futures:
            loop_unusable = not self._loop or self._loop.is_closed()
            if loop_unusable:
                logging.warning(
                    f"Executor loop ({self._loop!r}) is closed or unavailable. Cannot reliably wait for "
                    f"{len(self._cancellation_futures)} cancellation futures."
                )
                self._cancellation_futures.clear()
            else:
                # concurrent.futures.Future -> asyncio.Future
                asyncio_futures = [asyncio.wrap_future(f, loop=self._loop) for f in self._cancellation_futures]

                if not asyncio_futures:
                    logging.debug("No valid cancellation futures to await.")
                else:
                    logging.debug(
                        f"Waiting for {len(asyncio_futures)} cancellation futures to complete on loop {self._loop!r}..."
                    )
                    results = await asyncio.gather(*asyncio_futures, return_exceptions=True)
                    # gather keeps input order, so results pair up with the stored futures.
                    for position, outcome in enumerate(results):
                        original_future = self._cancellation_futures[position]
                        if isinstance(outcome, Exception):
                            logging.warning(f"Cancellation future {original_future!r} failed: {outcome}")
                        else:
                            logging.debug(f"Cancellation future {original_future!r} completed successfully.")

                self._cancellation_futures.clear()

        logging.debug(f"Stopping container {self.container_name}...")
        await asyncio.to_thread(container.stop)
        logging.debug(f"Container {self.container_name} stopped.")

    except DockerException as e:
        logging.error(f"Docker error while stopping container {self.container_name}: {e}")
    except Exception as e:
        logging.exception(f"Unexpected error during stop operation for container {self.container_name}: {e}")
    finally:
        self._running = False
        self._cancellation_futures.clear()
|
||||
|
||||
async def start(self) -> None:
    """(Experimental) Start the code executor.

    Creates a temporary working directory when none was provided, connects to
    Docker, pulls the image if it is not present, removes any existing container
    with the same name, then creates and starts a new container with the bind
    directory mounted at ``/workspace``. When ``stop_container`` is set, a
    cleanup that calls ``stop`` is registered with ``asyncio_atexit``.

    Raises:
        RuntimeError: if connecting to Docker fails because Docker is missing, or
            with an unexpected (non-Docker) error.
        ValueError: if the container does not reach the running state.
    """
    if self._work_dir is None and self._temp_dir is None:
        self._temp_dir = tempfile.TemporaryDirectory()
        self._temp_dir_path = Path(self._temp_dir.name)
        self._temp_dir_path.mkdir(exist_ok=True)

    # Start a container from the image, ready to exec commands later
    try:
        client = docker.from_env()
    except DockerException as e:
        if "FileNotFoundError" in str(e):
            raise RuntimeError("Failed to connect to Docker. Please ensure Docker is installed and running.") from e
        raise
    except Exception as e:
        raise RuntimeError(f"Unexpected error while connecting to Docker: {str(e)}") from e

    # Check if the image exists
    try:
        await asyncio.to_thread(client.images.get, self._image)
    except ImageNotFound:
        # TODO logger
        logging.info(f"Pulling image {self._image}...")
        # Let the docker exception escape if this fails.
        await asyncio.to_thread(client.images.pull, self._image)

    # Prepare the command (if needed)
    shell_command = "/bin/sh"
    if self._init_command:
        command = ["-c", f"{(self._init_command)};exec {shell_command}"]
    else:
        command = None

    # Check if a container with the same name already exists and remove it
    try:
        stale_container = await asyncio.to_thread(client.containers.get, self.container_name)
        await asyncio.to_thread(stale_container.remove, force=True)
    except NotFound:
        pass

    # The bind directory comes first; extra volumes are merged in afterwards.
    mounts = {str(self.bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}}
    mounts.update(self._extra_volumes)

    self._container = await asyncio.to_thread(
        client.containers.create,
        self._image,
        name=self.container_name,
        entrypoint=shell_command,
        command=command,
        tty=True,
        detach=True,
        auto_remove=self._auto_remove,
        volumes=mounts,
        working_dir="/workspace",
        extra_hosts=self._extra_hosts,
        device_requests=self._device_requests,
    )
    await asyncio.to_thread(self._container.start)

    await _wait_for_ready(self._container)

    async def cleanup() -> None:
        await self.stop()
        asyncio_atexit.unregister(cleanup)  # type: ignore

    if self._stop_container:
        asyncio_atexit.register(cleanup)  # type: ignore

    # Check if the container is running
    if self._container.status != "running":
        logs_str = self._container.logs().decode("utf-8")
        raise ValueError(f"Failed to start container from image {self._image}. Logs: {logs_str}")

    # Remember the loop so cancellations can schedule kill commands on it.
    self._loop = asyncio.get_running_loop()
    self._cancellation_futures = []
    logging.debug(f"Executor started, associated with event loop: {self._loop!r}")

    self._running = True
|
||||
|
||||
def _to_config(self) -> DockerCommandLineCodeExecutorConfig:
    """(Experimental) Build a config object describing this executor.

    Functions registered on the executor are not serialized; an info-level
    log record is emitted when any are present.
    """
    if self._functions:
        logging.info("Functions will not be included in serialized configuration")

    # Directory settings are stored as strings; unset ones are stored as None.
    work_dir = str(self._work_dir) if self._work_dir else None
    bind_dir = str(self._bind_dir) if self._bind_dir else None

    return DockerCommandLineCodeExecutorConfig(
        image=self._image,
        container_name=self.container_name,
        timeout=self._timeout,
        work_dir=work_dir,
        bind_dir=bind_dir,
        auto_remove=self._auto_remove,
        stop_container=self._stop_container,
        functions_module=self._functions_module,
        extra_volumes=self._extra_volumes,
        extra_hosts=self._extra_hosts,
        init_command=self._init_command,
        delete_tmp_files=self._delete_tmp_files,
    )
|
||||
|
||||
@classmethod
def _from_config(cls, config: DockerCommandLineCodeExecutorConfig) -> Self:
    """(Experimental) Build an executor from a config object.

    Functions are never part of the config, so the new instance is created
    with an empty function list.
    """
    # Directory settings are stored as strings in the config; turn the ones
    # that are set back into Path objects.
    work_dir = Path(config.work_dir) if config.work_dir else None
    bind_dir = Path(config.bind_dir) if config.bind_dir else None

    return cls(
        image=config.image,
        container_name=config.container_name,
        timeout=config.timeout,
        work_dir=work_dir,
        bind_dir=bind_dir,
        auto_remove=config.auto_remove,
        stop_container=config.stop_container,
        functions=[],  # Functions not restored from config
        functions_module=config.functions_module,
        extra_volumes=config.extra_volumes,
        extra_hosts=config.extra_hosts,
        init_command=config.init_command,
        delete_tmp_files=config.delete_tmp_files,
    )
|
||||
@@ -0,0 +1,10 @@
|
||||
from ._docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterCodeResult
|
||||
from ._jupyter_server import DockerJupyterServer, JupyterClient, JupyterKernelClient
|
||||
|
||||
__all__ = [
|
||||
"DockerJupyterCodeExecutor",
|
||||
"DockerJupyterServer",
|
||||
"JupyterClient",
|
||||
"JupyterKernelClient",
|
||||
"DockerJupyterCodeResult",
|
||||
]
|
||||
@@ -0,0 +1,300 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
|
||||
from agentdhal_extensions.code_executors._common import silence_pip
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._jupyter_server import JupyterClient, JupyterConnectable, JupyterConnectionInfo, JupyterKernelClient
|
||||
|
||||
|
||||
@dataclass
class DockerJupyterCodeResult(CodeResult):
    """(Experimental) Result of running code blocks through the Docker Jupyter executor."""

    # Paths of the files written while collecting the execution outputs.
    output_files: list[Path]
|
||||
|
||||
|
||||
class DockerJupyterCodeExecutorConfig(BaseModel):
    """Declarative configuration for :class:`DockerJupyterCodeExecutor`."""

    # Server object (or its connection info) the executor talks to.
    jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo]
    # Kernel spec name to launch.
    kernel_name: str = "python3"
    # Timeout, in seconds, used for code execution.
    timeout: int = 60
    # Directory for saved output files; None leaves the choice to the executor.
    output_dir: Optional[Union[Path, str]] = None

    class Config:
        # The jupyter_server field holds non-pydantic types.
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class DockerJupyterCodeExecutor(CodeExecutor, Component[DockerJupyterCodeExecutorConfig]):
|
||||
"""(Experimental) A code executor class that executes code statefully using
|
||||
a Jupyter server supplied to this class.
|
||||
|
||||
Each execution is stateful and can access variables created from previous
|
||||
executions in the same session.
|
||||
|
||||
To use this, you need to install the following dependencies:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "agentdhal-ext[docker-jupyter-executor]"
|
||||
|
||||
Args:
|
||||
jupyter_server (Union[JupyterConnectable, JupyterConnectionInfo]): The Jupyter server to use.
|
||||
kernel_name (str): The kernel name to use. Make sure it is installed.
|
||||
By default, it is "python3".
|
||||
timeout (int): The timeout for code execution, by default 60.
|
||||
output_dir (str): The directory to save output files, by default None.
|
||||
|
||||
Example of using it directly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeBlock
|
||||
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with DockerJupyterServer() as jupyter_server:
|
||||
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
|
||||
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken())
|
||||
print(code_result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using it with your own jupyter image:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeBlock
|
||||
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with DockerJupyterServer(custom_image_name="your_custom_images_name", expose_port=8888) as jupyter_server:
|
||||
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
|
||||
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token=CancellationToken())
|
||||
print(code_result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with DockerJupyterServer() as jupyter_server:
|
||||
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
|
||||
tool = PythonCodeExecutionTool(executor)
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent("assistant", model_client=model_client, tools=[tool])
|
||||
result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.")
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.docker_jupyter import DockerJupyterCodeExecutor, DockerJupyterServer
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with DockerJupyterServer() as jupyter_server:
|
||||
async with DockerJupyterCodeExecutor(jupyter_server=jupyter_server) as executor:
|
||||
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor)
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```python
|
||||
print('Hello world')
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = DockerJupyterCodeExecutorConfig
|
||||
component_provider_override = "agentdhal_extensions.code_executors.docker_jupyter.DockerJupyterCodeExecutor"
|
||||
|
||||
def __init__(
    self,
    jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo],
    kernel_name: str = "python3",
    timeout: int = 60,
    output_dir: Path | None = None,
):
    """Set up the executor; no kernel is started until :meth:`start` runs.

    Args:
        jupyter_server: Either an object exposing ``connection_info`` or the
            connection info itself.
        kernel_name: Name of the kernel spec to launch.
        timeout: Timeout in seconds passed to kernel readiness checks and
            code execution; must be at least 1.
        output_dir: Directory where image/HTML outputs are written. When
            omitted, the server's ``_bind_dir`` attribute is used if it has
            one; otherwise a temporary directory is created.

    Raises:
        ValueError: If ``timeout`` is below 1 or ``jupyter_server`` is of an
            unsupported type.
    """
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")

    if isinstance(jupyter_server, JupyterConnectable):
        self._connection_info = jupyter_server.connection_info
    elif isinstance(jupyter_server, JupyterConnectionInfo):
        self._connection_info = jupyter_server
    else:
        raise ValueError("jupyter_server must be a JupyterConnectable or JupyterConnectionInfo.")

    self._temp_dir = None
    self._output_dir = output_dir or getattr(jupyter_server, "_bind_dir", None)
    if not self._output_dir:
        # Keep the TemporaryDirectory object alive on the instance. Using it
        # as a context manager here would delete the directory as soon as
        # the ``with`` block exited, leaving ``_output_dir`` pointing at a
        # path that no longer exists when outputs are saved.
        self._temp_dir = tempfile.TemporaryDirectory()
        self._output_dir = Path(self._temp_dir.name)

    self._jupyter_client = JupyterClient(self._connection_info)

    self._kernel_name = kernel_name
    self._timeout = timeout
    self._async_jupyter_kernel_client: Optional[JupyterKernelClient] = None
    self._kernel_id: Optional[str] = None
|
||||
|
||||
async def _ensure_async_kernel_client(self) -> JupyterKernelClient:
    """Return the kernel client, starting a kernel and connecting first if needed."""
    if self._kernel_id is None:
        await self.start()
    assert self._kernel_id is not None

    client = self._async_jupyter_kernel_client
    if client is None:
        # Connect lazily and cache the client for subsequent calls.
        client = await self._jupyter_client.get_kernel_client(self._kernel_id)
        self._async_jupyter_kernel_client = client
    return client
|
||||
|
||||
async def execute_code_blocks(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> DockerJupyterCodeResult:
    """(Experimental) Run the given code blocks in order and collect their output.

    Each block is executed as a cell in the Jupyter kernel. See
    https://jupyter-client.readthedocs.io/en/stable/messaging.html for the
    message protocol.

    Args:
        code_blocks (List[CodeBlock]): Code blocks to execute, in order.
        cancellation_token (CancellationToken): Linked to every execution
            task so the caller can cancel it.

    Returns:
        DockerJupyterCodeResult: Exit code 0 with the joined outputs when all
        blocks succeed; exit code 1 as soon as the kernel is not ready or a
        block fails, carrying the output gathered so far.
    """
    client = await self._ensure_async_kernel_client()

    # Wait for kernel to be ready using async client
    if not await client.wait_for_ready(timeout_seconds=self._timeout):
        return DockerJupyterCodeResult(exit_code=1, output="ERROR: Kernel not ready", output_files=[])

    collected: List[str] = []
    saved_files: List[Path] = []

    for block in code_blocks:
        source = silence_pip(block.code, block.language)

        # Run as a task so the cancellation token can cancel it.
        task = asyncio.create_task(client.execute(source, timeout_seconds=self._timeout))
        cancellation_token.link_future(task)
        outcome = await task

        if not outcome.is_ok:
            # Stop at the first failing block, reporting what ran before it.
            previous = "\n".join(map(str, collected))
            return DockerJupyterCodeResult(
                exit_code=1, output=previous + "\nERROR: " + outcome.output, output_files=saved_files
            )

        collected.append(outcome.output)
        for item in outcome.data_items:
            if item.mime_type == "image/png":
                saved = self._save_image(item.data)
            elif item.mime_type == "text/html":
                saved = self._save_html(item.data)
            else:
                # Other payloads are reported inline as JSON.
                collected.append(json.dumps(item.data))
                continue
            collected.append(saved)
            saved_files.append(Path(saved))

    return DockerJupyterCodeResult(
        exit_code=0, output="\n".join(map(str, collected)), output_files=saved_files
    )
|
||||
|
||||
async def restart(self) -> None:
    """(Experimental) Restart the kernel and drop the cached kernel client."""
    kernel_id = self._kernel_id
    if kernel_id is not None:
        await self._jupyter_client.restart_kernel(kernel_id)

    # Close and forget the client so the next use creates a fresh one.
    client = self._async_jupyter_kernel_client
    if client is not None:
        await client.stop()
        self._async_jupyter_kernel_client = None
|
||||
|
||||
async def start(self) -> None:
    """(Experimental) Start a kernel of the configured kind and remember its id.

    Raises:
        ValueError: If the configured kernel name is not among the kernel
            specs the server reports.
    """
    specs = await self._jupyter_client.list_kernel_specs()
    installed = specs["kernelspecs"]
    if self._kernel_name not in installed:
        raise ValueError(f"Kernel {self._kernel_name} is not installed.")
    self._kernel_id = await self._jupyter_client.start_kernel(self._kernel_name)
|
||||
|
||||
def _save_image(self, image_data_base64: str) -> str:
|
||||
"""Save image data to a file."""
|
||||
image_data = base64.b64decode(image_data_base64)
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
path = os.path.join(str(self._output_dir), filename)
|
||||
with open(path, "wb") as f:
|
||||
f.write(image_data)
|
||||
return os.path.abspath(path)
|
||||
|
||||
def _save_html(self, html_data: str) -> str:
|
||||
"""Save html data to a file."""
|
||||
filename = f"{uuid.uuid4().hex}.html"
|
||||
path = os.path.join(str(self._output_dir), filename)
|
||||
with open(path, "w") as f:
|
||||
f.write(html_data)
|
||||
return os.path.abspath(path)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the kernel."""
|
||||
if self._kernel_id is not None:
|
||||
await self._jupyter_client.delete_kernel(self._kernel_id)
|
||||
if self._async_jupyter_kernel_client is not None:
|
||||
await self._async_jupyter_kernel_client.stop()
|
||||
self._async_jupyter_kernel_client = None
|
||||
await self._jupyter_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
    """Start a kernel on entering the ``async with`` block and return the executor."""
    await self.start()
    return self
|
||||
|
||||
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    """Stop the executor on leaving the ``async with`` block; exceptions are not suppressed."""
    await self.stop()
|
||||
@@ -0,0 +1,430 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Optional, Protocol, Type, Union, cast, runtime_checkable
|
||||
|
||||
import aiohttp
|
||||
import docker
|
||||
import docker.errors
|
||||
import requests
|
||||
import websockets
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
@dataclass
class JupyterConnectionInfo:
    """(Experimental) Address and credentials used to reach a Jupyter gateway server."""

    host: str
    """`str` - Host of the Jupyter gateway server"""

    use_https: bool
    """`bool` - Whether to use HTTPS"""

    port: Optional[int] = None
    """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""

    token: Optional[str] = None
    """`Optional[str]` - Token for authentication. If None, no token is used"""
|
||||
|
||||
|
||||
@runtime_checkable
class JupyterConnectable(Protocol):
    """(Experimental) Structural type for objects that can provide Jupyter connection info.

    The protocol is runtime-checkable, so ``isinstance`` can be used to test
    for the ``connection_info`` member.
    """

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Return the connection information for this connectable."""
        ...
|
||||
|
||||
|
||||
class JupyterClient:
|
||||
def __init__(self, connection_info: JupyterConnectionInfo):
    """(Experimental) A client for communicating with a Jupyter gateway server.

    Creates a synchronous ``requests`` session with retrying adapters; the
    asynchronous ``aiohttp`` session is created lazily on first use.

    Args:
        connection_info (JupyterConnectionInfo): Connection information
    """
    self._connection_info = connection_info
    self._session = requests.Session()
    retries = Retry(total=5, backoff_factor=0.1)
    # Mount the retry policy for both schemes: connection info may request
    # HTTPS, and mounting only "http://" would leave those requests without
    # the retry policy.
    self._session.mount("http://", HTTPAdapter(max_retries=retries))
    self._session.mount("https://", HTTPAdapter(max_retries=retries))
    # Create aiohttp session for async requests
    self._async_session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def _ensure_async_session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp session, creating it on first use."""
    session = self._async_session
    if session is None:
        session = aiohttp.ClientSession()
        self._async_session = session
    return session
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
if self._connection_info.token is None:
|
||||
return {}
|
||||
return {"Authorization": f"token {self._connection_info.token}"}
|
||||
|
||||
def _get_api_base_url(self) -> str:
|
||||
protocol = "https" if self._connection_info.use_https else "http"
|
||||
port = f":{self._connection_info.port}" if self._connection_info.port else ""
|
||||
return f"{protocol}://{self._connection_info.host}{port}"
|
||||
|
||||
def _get_ws_base_url(self) -> str:
|
||||
port = f":{self._connection_info.port}" if self._connection_info.port else ""
|
||||
return f"ws://{self._connection_info.host}{port}"
|
||||
|
||||
async def list_kernel_specs(self) -> Dict[str, Dict[str, str]]:
    """Fetch the kernel specs the server reports.

    The blocking ``requests`` call runs in a worker thread so the event loop
    is not stalled while waiting for the server (including retries).

    Returns:
        The decoded JSON body of ``/api/kernelspecs``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = await asyncio.to_thread(
        self._session.get, f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()
    )
    # Surface HTTP errors here instead of as a confusing missing-key error
    # in the caller that inspects the body.
    response.raise_for_status()
    return cast(Dict[str, Dict[str, str]], response.json())
|
||||
|
||||
async def list_kernels(self) -> List[Dict[str, str]]:
    """Fetch the list of kernels known to the server.

    The blocking ``requests`` call runs in a worker thread so the event loop
    is not stalled while waiting for the server (including retries).

    Returns:
        The decoded JSON body of ``/api/kernels``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = await asyncio.to_thread(
        self._session.get, f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()
    )
    response.raise_for_status()
    return cast(List[Dict[str, str]], response.json())
|
||||
|
||||
async def start_kernel(self, kernel_spec_name: str) -> str:
|
||||
"""Start a new kernel asynchronously.
|
||||
|
||||
Args:
|
||||
kernel_spec_name (str): Name of the kernel spec to start
|
||||
|
||||
Returns:
|
||||
str: ID of the started kernel
|
||||
"""
|
||||
session = await self._ensure_async_session()
|
||||
async with session.post(
|
||||
f"{self._get_api_base_url()}/api/kernels",
|
||||
headers=self._get_headers(),
|
||||
json={"name": kernel_spec_name},
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return cast(str, data["id"])
|
||||
|
||||
async def delete_kernel(self, kernel_id: str) -> None:
    """Delete the kernel with the given id; raises on an HTTP error status."""
    url = f"{self._get_api_base_url()}/api/kernels/{kernel_id}"
    http = await self._ensure_async_session()
    async with http.delete(url, headers=self._get_headers()) as reply:
        reply.raise_for_status()
|
||||
|
||||
async def restart_kernel(self, kernel_id: str) -> None:
    """Ask the server to restart the kernel with the given id; raises on an HTTP error status."""
    url = f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart"
    http = await self._ensure_async_session()
    async with http.post(url, headers=self._get_headers()) as reply:
        reply.raise_for_status()
|
||||
|
||||
async def get_kernel_client(self, kernel_id: str) -> "JupyterKernelClient":
    """Open a websocket to the kernel's ``channels`` endpoint and wrap it in a kernel client."""
    channels_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
    # Using websockets library for async websocket connections
    connection = await websockets.connect(channels_url, additional_headers=self._get_headers())
    return JupyterKernelClient(connection)
|
||||
|
||||
async def close(self) -> None:
    """Close the async session (if one was created) and the synchronous session."""
    async_session = self._async_session
    if async_session is not None:
        await async_session.close()
        self._async_session = None
    self._session.close()
|
||||
|
||||
|
||||
@dataclass
class DataItem:
    """One non-text output item collected from a kernel execution."""

    # MIME type key under which the kernel reported the payload.
    mime_type: str
    # The payload as delivered by the kernel.
    data: str
|
||||
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Outcome of executing one piece of code on a kernel."""

    # False when the execution reported an error or timed out.
    is_ok: bool
    # Collected text output, or the error description on failure.
    output: str
    # Non-text outputs gathered during a successful execution.
    data_items: List[DataItem]
|
||||
|
||||
|
||||
class JupyterKernelClient:
|
||||
"""An asynchronous client for communicating with a Jupyter kernel."""
|
||||
|
||||
def __init__(self, websocket: websockets.ClientConnection) -> None:
    """Wrap an open websocket connection and assign a fresh session id to this client."""
    self._websocket = websocket
    self._session_id = uuid.uuid4().hex
|
||||
|
||||
async def __aenter__(self) -> Self:
    """Return the client itself; the websocket is already connected."""
    return self
|
||||
|
||||
async def __aexit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Close the websocket on leaving the ``async with`` block; exceptions are not suppressed."""
    await self.stop()
|
||||
|
||||
async def stop(self) -> None:
    """Close the underlying websocket connection."""
    await self._websocket.close()
|
||||
|
||||
async def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str:
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
message_id = uuid.uuid4().hex
|
||||
message = {
|
||||
"header": {
|
||||
"username": "agentdhal",
|
||||
"version": "5.0",
|
||||
"session": self._session_id,
|
||||
"msg_id": message_id,
|
||||
"msg_type": message_type,
|
||||
"date": timestamp,
|
||||
},
|
||||
"parent_header": {},
|
||||
"channel": channel,
|
||||
"content": content,
|
||||
"metadata": {},
|
||||
"buffers": {},
|
||||
}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
return message_id
|
||||
|
||||
async def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
if timeout_seconds is not None:
|
||||
data = await asyncio.wait_for(self._websocket.recv(), timeout=timeout_seconds)
|
||||
else:
|
||||
data = await self._websocket.recv()
|
||||
if isinstance(data, bytes):
|
||||
return cast(Dict[str, Any], json.loads(data.decode("utf-8")))
|
||||
return cast(Dict[str, Any], json.loads(data))
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
async def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
|
||||
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")
|
||||
while True:
|
||||
message = await self._receive_message(timeout_seconds)
|
||||
# This means we timed out with no new messages.
|
||||
if message is None:
|
||||
return False
|
||||
if (
|
||||
message.get("parent_header", {}).get("msg_id") == message_id
|
||||
and message["msg_type"] == "kernel_info_reply"
|
||||
):
|
||||
return True
|
||||
|
||||
async def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
    """Send code to the kernel and gather its output until the kernel reports idle.

    Args:
        code: Source code to execute.
        timeout_seconds: Timeout applied to each individual receive; ``None``
            waits without a limit.

    Returns:
        A successful result with the joined text output and collected data
        items, or a failed result if a receive timed out or the kernel
        reported an error.
    """
    request = {
        "code": code,
        "silent": False,
        "store_history": True,
        "user_expressions": {},
        "allow_stdin": False,
        "stop_on_error": True,
    }
    request_id = await self._send_message(content=request, channel="shell", message_type="execute_request")

    texts: List[str] = []
    items: List[DataItem] = []
    while True:
        reply = await self._receive_message(timeout_seconds)
        if reply is None:
            return ExecutionResult(
                is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
            )

        # Ignore messages that are not for this execution.
        if reply.get("parent_header", {}).get("msg_id") != request_id:
            continue

        kind = reply["msg_type"]
        body = reply["content"]

        if kind == "error":
            # Output is an error.
            return ExecutionResult(
                is_ok=False,
                output=f"ERROR: {body['ename']}: {body['evalue']}\n{body['traceback']}",
                data_items=[],
            )

        if kind == "stream":
            texts.append(body["text"])
        elif kind in ("execute_result", "display_data"):
            for mime, payload in body["data"].items():
                if mime == "text/plain":
                    texts.append(payload)
                elif mime.startswith("image/") or mime == "text/html":
                    items.append(DataItem(mime_type=mime, data=payload))
                else:
                    texts.append(json.dumps(payload))
        elif kind == "status" and body["execution_state"] == "idle":
            # The kernel went idle for our request: execution is finished.
            break

    return ExecutionResult(is_ok=True, output="\n".join(map(str, texts)), data_items=items)
|
||||
|
||||
|
||||
class DockerJupyterServer(JupyterConnectable):
|
||||
DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation
|
||||
|
||||
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
|
||||
|
||||
USER ${NB_UID}
|
||||
RUN mamba install --yes jupyter_kernel_gateway ipykernel && \
|
||||
mamba clean --all -f -y && \
|
||||
fix-permissions "${CONDA_DIR}" && \
|
||||
fix-permissions "/home/${NB_USER}"
|
||||
|
||||
ENV TOKEN="UNSET"
|
||||
CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \
|
||||
--KernelGatewayApp.port=8888 \
|
||||
--KernelGatewayApp.auth_token="${TOKEN}" \
|
||||
--JupyterApp.answer_yes=true \
|
||||
--JupyterWebsocketPersonality.list_kernels=true
|
||||
|
||||
EXPOSE 8888
|
||||
|
||||
WORKDIR "${HOME}"
|
||||
"""
|
||||
|
||||
class GenerateToken:
    """Marker type: pass an instance as ``token`` to request a randomly generated token."""
|
||||
|
||||
def __init__(
    self,
    *,
    custom_image_name: Optional[str] = None,
    container_name: Optional[str] = None,
    auto_remove: bool = True,
    stop_container: bool = True,
    docker_env: Optional[Dict[str, str]] = None,
    expose_port: int = 8888,
    token: Optional[Union[str, GenerateToken]] = None,
    work_dir: Union[Path, str] = "/workspace",
    bind_dir: Optional[Union[Path, str]] = None,
):
    """Start a Jupyter kernel gateway server in a Docker container.

    Args:
        custom_image_name: Custom Docker image to use. If None, builds and uses bundled image.
        container_name: Name for the Docker container. Auto-generated if None.
        auto_remove: If True, container will be deleted when stopped.
        stop_container: If True, container stops on program exit or when context manager exits.
        docker_env: Additional environment variables for the container.
        expose_port: Port to expose for Jupyter connection.
        token: Authentication token. If GenerateToken, creates random token. Empty for no auth.
        work_dir: Working directory inside the container.
        bind_dir: Local directory to bind to container's work_dir.

    Raises:
        ValueError: If the custom image does not exist, the container does
            not reach the running state, or no container id is obtained.
    """
    # Generate container name if not provided
    container_name = container_name or f"agentdhal-jupyterkernelgateway-{uuid.uuid4()}"

    # Initialize Docker client
    client = docker.from_env()

    # Set up bind directory if specified
    self._bind_dir: Optional[Path] = None
    if bind_dir:
        self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir
        self._bind_dir.mkdir(exist_ok=True)
        os.chmod(self._bind_dir, 0o777)

    # Determine and prepare Docker image
    image_name = custom_image_name or "agentdhal-jupyterkernelgateway"
    if not custom_image_name:
        try:
            client.images.get(image_name)
        except docker.errors.ImageNotFound:
            # Build default image if not found
            here = Path(__file__).parent
            dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8"))
            logging.info(f"Building image {image_name}...")
            client.images.build(path=str(here), fileobj=dockerfile, tag=image_name)
            logging.info(f"Image {image_name} built successfully")
    else:
        # Verify custom image exists
        try:
            client.images.get(image_name)
        except docker.errors.ImageNotFound as err:
            raise ValueError(f"Custom image {image_name} does not exist") from err

    if docker_env is None:
        docker_env = {}
    if token is None:
        token = DockerJupyterServer.GenerateToken()

    # Set up authentication token
    self._token = secrets.token_hex(32) if isinstance(token, DockerJupyterServer.GenerateToken) else token

    # Prepare environment variables
    env = {"TOKEN": self._token}
    env.update(docker_env)

    # Define volume configuration if bind directory is specified. The host
    # side must be an absolute path: Docker interprets a relative source as
    # the name of a volume rather than as a directory to bind-mount.
    volumes = (
        {str(self._bind_dir.resolve()): {"bind": str(work_dir), "mode": "rw"}} if self._bind_dir else None
    )

    # Start the container
    container = client.containers.run(
        image_name,
        detach=True,
        auto_remove=auto_remove,
        environment=env,
        publish_all_ports=True,
        name=container_name,
        volumes=volumes,
        working_dir=str(work_dir),
    )

    # Wait for container to be ready
    self._wait_for_ready(container)

    # Store container information
    self._container = container
    self._port = int(container.ports[f"{expose_port}/tcp"][0]["HostPort"])
    self._container_id = container.id
    self._expose_port = expose_port

    if self._container_id is None:
        raise ValueError("Failed to obtain container id.")

    # Define cleanup function
    def cleanup() -> None:
        try:
            assert self._container_id is not None
            inner_container = client.containers.get(self._container_id)
            inner_container.stop()
        except docker.errors.NotFound:
            # The container is already gone; nothing left to stop.
            pass
        atexit.unregister(cleanup)

    # Register cleanup if container should be stopped automatically
    if stop_container:
        atexit.register(cleanup)

    self._cleanup_func = cleanup
    self._stop_container = stop_container
|
||||
|
||||
@property
def connection_info(self) -> JupyterConnectionInfo:
    """Connection details for this server: plain HTTP on 127.0.0.1 at the
    published host port, with the server's token."""
    return JupyterConnectionInfo(
        host="127.0.0.1",
        use_https=False,
        port=self._port,
        token=self._token,
    )
|
||||
|
||||
def _wait_for_ready(self, container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
|
||||
elapsed_time = 0.0
|
||||
while container.status != "running" and elapsed_time < timeout:
|
||||
sleep(stop_time)
|
||||
elapsed_time += stop_time
|
||||
container.reload()
|
||||
continue
|
||||
if container.status != "running":
|
||||
raise ValueError("Container failed to start")
|
||||
|
||||
async def stop(self) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._cleanup_func)
|
||||
|
||||
async def get_client(self) -> JupyterClient:
    """Create a new client configured with this server's connection info."""
    info = self.connection_info
    return JupyterClient(info)
|
||||
|
||||
async def __aenter__(self) -> Self:
    """Return the server itself; the container was already started by the constructor."""
    return self
|
||||
|
||||
async def __aexit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Stop the server on leaving the ``async with`` block; exceptions are not suppressed."""
    await self.stop()
|
||||
@@ -0,0 +1,6 @@
|
||||
from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult
|
||||
|
||||
__all__ = [
|
||||
"JupyterCodeExecutor",
|
||||
"JupyterCodeResult",
|
||||
]
|
||||
@@ -0,0 +1,335 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from agentdhal_core import Component
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
else:
|
||||
from typing_extensions import Self
|
||||
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import Optional, Union
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, CodeResult
|
||||
from nbclient import NotebookClient
|
||||
from nbformat import NotebookNode
|
||||
from nbformat import v4 as nbformat
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._common import silence_pip
|
||||
|
||||
|
||||
@dataclass
class JupyterCodeResult(CodeResult):
    """Result returned by the Jupyter code executor.

    Adds the list of files produced while processing cell outputs to the
    fields inherited from :class:`CodeResult`.
    """

    # Paths of the files the executor saved from cell outputs.
    output_files: list[Path]
|
||||
|
||||
|
||||
class JupyterCodeExecutorConfig(BaseModel):
    """Declarative configuration from which a JupyterCodeExecutor can be rebuilt."""

    # Name of the Jupyter kernel handed to the notebook client.
    kernel_name: str = "python3"
    # Execution timeout; the executor's constructor rejects values below 1.
    timeout: int = 60
    # Directory for saved output files; None lets the executor create a temporary one.
    output_dir: Optional[str] = None
|
||||
|
||||
|
||||
class JupyterCodeExecutor(CodeExecutor, Component[JupyterCodeExecutorConfig]):
|
||||
"""A code executor class that executes code statefully using [nbclient](https://github.com/jupyter/nbclient).
|
||||
|
||||
.. danger::
|
||||
|
||||
This will execute code on the local machine. If being used with LLM generated code, caution should be used.
|
||||
|
||||
Example of using it directly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeBlock
|
||||
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with JupyterCodeExecutor() as executor:
|
||||
cancel_token = CancellationToken()
|
||||
code_blocks = [CodeBlock(code="print('hello world!')", language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancel_token)
|
||||
print(code_result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using it with :class:`~agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with JupyterCodeExecutor() as executor:
|
||||
tool = PythonCodeExecutionTool(executor)
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent("assistant", model_client=model_client, tools=[tool])
|
||||
result = await agent.run(task="What is the 10th Fibonacci number? Use Python to calculate it.")
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using it inside a :class:`~agentdhal_agentchat.agents._code_executor_agent.CodeExecutorAgent`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import CodeExecutorAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.code_executors.jupyter import JupyterCodeExecutor
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with JupyterCodeExecutor() as executor:
|
||||
code_executor_agent = CodeExecutorAgent("code_executor", code_executor=executor)
|
||||
task = TextMessage(
|
||||
content='''Here is some code
|
||||
```python
|
||||
print('Hello world')
|
||||
```
|
||||
''',
|
||||
source="user",
|
||||
)
|
||||
response = await code_executor_agent.on_messages([task], CancellationToken())
|
||||
print(response.chat_message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
kernel_name (str): The kernel name to use. By default, "python3".
|
||||
timeout (int): The timeout for code execution, by default 60.
|
||||
output_dir (Path): The directory to save output files, by default a temporary directory.
|
||||
|
||||
|
||||
.. note::
|
||||
Using the current directory (".") as output directory is deprecated. Using it will raise a deprecation warning.
|
||||
"""
|
||||
|
||||
component_config_schema = JupyterCodeExecutorConfig
|
||||
component_provider_override = "agentdhal_extensions.code_executors.jupyter.JupyterCodeExecutor"
|
||||
|
||||
def __init__(
    self,
    kernel_name: str = "python3",
    timeout: int = 60,
    output_dir: Optional[Union[Path, str]] = None,
):
    """Store the executor settings and make sure the output directory exists.

    Args:
        kernel_name (str): The kernel name to use. By default, "python3".
        timeout (int): The timeout for code execution, by default 60.
        output_dir (Optional[Union[Path, str]]): Directory for output files.
            When omitted, a fresh directory from ``tempfile.mkdtemp()`` is used.

    Raises:
        ValueError: If ``timeout`` is smaller than 1.
    """
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")

    # Resolve the output location, then create it (including missing parents).
    if output_dir is None:
        chosen_dir = Path(tempfile.mkdtemp())
    else:
        chosen_dir = Path(output_dir)
    self._output_dir: Path = chosen_dir
    self._output_dir.mkdir(exist_ok=True, parents=True)

    self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
    self._temp_dir_path: Optional[Path] = None

    # The kernel is only launched by start(); until then nothing is running.
    self._started = False

    self._kernel_name = kernel_name
    self._timeout = timeout

    self._client: Optional[NotebookClient] = None
    self.kernel_context: Optional[AbstractAsyncContextManager[None]] = None
|
||||
|
||||
async def execute_code_blocks(
    self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken
) -> JupyterCodeResult:
    """Run the given code blocks in order and merge their results.

    Execution stops after the first block that finishes with a non-zero exit
    code; the outputs gathered so far (including the failing one) are kept.

    Args:
        code_blocks (list[CodeBlock]): The code blocks to execute.
        cancellation_token (CancellationToken): Token forwarded to each block execution.

    Returns:
        JupyterCodeResult: Exit code of the last executed block, all text
        outputs joined by newlines, and every output file that was saved.
    """
    collected_text: list[str] = []
    collected_files: list[Path] = []
    status = 0

    for block in code_blocks:
        block_result = await self._execute_code_block(block, cancellation_token)
        status = block_result.exit_code
        collected_text.append(block_result.output)
        collected_files.extend(block_result.output_files)

        if status != 0:
            # A failing block aborts the remaining ones.
            break

    return JupyterCodeResult(exit_code=status, output="\n".join(collected_text), output_files=collected_files)
|
||||
|
||||
async def _execute_code_block(
    self, code_block: CodeBlock, cancellation_token: CancellationToken
) -> JupyterCodeResult:
    """Execute one code block as a notebook cell and translate its outputs.

    Output handling:
        * ``stream`` outputs contribute their text.
        * ``error`` outputs contribute the traceback with ANSI escape
          sequences removed and set the exit code to 1.
        * ``execute_result`` / ``display_data`` outputs are processed per MIME
          type: plain text is appended, PNG images and HTML are saved to
          files, JPEG data is ignored, anything else is JSON-encoded.

    Args:
        code_block (CodeBlock): The code block to execute.
        cancellation_token (CancellationToken): Token linked to the execution task.

    Returns:
        JupyterCodeResult: The result of the code execution.
    """
    cell = nbformat.new_code_cell(silence_pip(code_block.code, code_block.language))  # type: ignore
    execute_task = asyncio.create_task(self._execute_cell(cell))

    cancellation_token.link_future(execute_task)
    # The task is wrapped in shield() before being passed to wait_for().
    finished_cell = await asyncio.wait_for(asyncio.shield(execute_task), timeout=self._timeout)

    text_parts: list[str] = []
    saved_files: list[Path] = []
    status = 0

    for entry in finished_cell.get("outputs", []):
        kind = entry.get("output_type")
        if kind == "stream":
            text_parts.append(entry.get("text", ""))
        elif kind == "error":
            joined_traceback = "\n".join(entry["traceback"])
            text_parts.append(re.sub(r"\x1b\[[0-9;]*[A-Za-z]", "", joined_traceback))
            status = 1
        elif kind in ("execute_result", "display_data"):
            for mime, content in entry.get("data", {}).items():
                if mime == "text/plain":
                    text_parts.append(content)
                elif mime == "image/png":
                    saved_files.append(self._save_image(content))
                elif mime == "image/jpeg":
                    # TODO: Should this also be encoded? Images are encoded as both png and jpg
                    continue
                elif mime == "text/html":
                    saved_files.append(self._save_html(content))
                else:
                    text_parts.append(json.dumps(content))
        # Any other output type is ignored.

    return JupyterCodeResult(exit_code=status, output="\n".join(text_parts), output_files=saved_files)
|
||||
|
||||
async def _execute_cell(self, cell: NotebookNode) -> NotebookNode:
    """Execute a single cell with the notebook client and return the executed cell.

    The cell is pushed onto the client's notebook only for the duration of
    the call, because ``async_execute_cell`` expects it there. It is removed
    again afterwards, as cells can take up a significant amount of memory
    (especially with images).

    Args:
        cell (NotebookNode): The code cell to execute.

    Returns:
        NotebookNode: The value returned by ``async_execute_cell``.

    Raises:
        RuntimeError: If the executor has not been started.
    """
    client = self._client
    if not client:
        raise RuntimeError("Executor must be started before executing cells")

    cells = client.nb.cells
    cells.append(cell)
    try:
        return await client.async_execute_cell(
            cell,
            cell_index=0,
        )
    finally:
        # Remove the temporary cell even when execution raises, times out or
        # is cancelled. Otherwise it would stay at index 0, and every later
        # call (which always passes cell_index=0) would refer to the stale cell.
        cells.pop()
|
||||
|
||||
def _save_image(self, image_data_base64: str) -> Path:
    """Decode base64 image data and write it to a new ``.png`` file.

    Args:
        image_data_base64 (str): Base64-encoded image bytes.

    Returns:
        Path: Absolute path of the written file, which is placed in the
        output directory under a random hexadecimal name.
    """
    raw_bytes = base64.b64decode(image_data_base64)
    target = self._output_dir / (uuid.uuid4().hex + ".png")
    target.write_bytes(raw_bytes)
    return target.absolute()
|
||||
|
||||
def _save_html(self, html_data: str) -> Path:
    """Write HTML data to a new ``.html`` file in the output directory.

    Args:
        html_data (str): The HTML text to store.

    Returns:
        Path: Absolute path of the written file (random hexadecimal name).
    """
    path = self._output_dir / f"{uuid.uuid4().hex}.html"
    # Always write UTF-8. Without an explicit encoding the locale's default is
    # used, which can fail or garble non-ASCII markup on some platforms.
    path.write_text(html_data, encoding="utf-8")
    return path.absolute()
|
||||
|
||||
async def restart(self) -> None:
    """Restart the code executor by stopping it and then starting it again."""
    for step in (self.stop, self.start):
        await step()
|
||||
|
||||
async def start(self) -> None:
    """(Experimental) Start the code executor.

    Creates an empty notebook and a ``NotebookClient`` for the configured
    kernel name and timeout, enters the client's kernel setup context and
    marks the executor as started. Calling it while already started does
    nothing. This method should be called before executing any code blocks.
    """
    if self._started:
        # Already running; starting twice is a no-op.
        return

    empty_notebook: NotebookNode = nbformat.new_notebook()  # type: ignore

    self._client = NotebookClient(
        nb=empty_notebook,
        kernel_name=self._kernel_name,
        timeout=self._timeout,
        allow_errors=True,
    )

    # The context is kept on the instance so that stop() can exit it later.
    self.kernel_context = self._client.async_setup_kernel()
    await self.kernel_context.__aenter__()

    self._started = True
|
||||
|
||||
async def stop(self) -> None:
    """(Experimental) Stop the code executor.

    Exits the kernel context (if one is held) and clears the client and the
    started flag. Calling it on an executor that is not started does nothing.

    The internal state is reset even when exiting the kernel context raises.
    Otherwise the executor would keep reporting itself as started and a later
    ``start()`` would silently return without creating a new kernel. The
    original exception still propagates to the caller.
    """
    if not self._started:
        return

    kernel_context = self.kernel_context
    try:
        if kernel_context is not None:
            await kernel_context.__aexit__(None, None, None)
    finally:
        self.kernel_context = None
        self._client = None
        self._started = False
|
||||
|
||||
def _to_config(self) -> JupyterCodeExecutorConfig:
    """Build the config object describing this instance.

    The output directory is read through the ``output_dir`` property and
    stored as a string.
    """
    settings = {
        "kernel_name": self._kernel_name,
        "timeout": self._timeout,
        "output_dir": str(self.output_dir),
    }
    return JupyterCodeExecutorConfig(**settings)
|
||||
|
||||
@property
def output_dir(self) -> Path:
    """The directory where output files are saved.

    Emits a ``DeprecationWarning`` when that directory is the current working
    directory, then returns the configured path unchanged.
    """
    # Compare resolved paths so that every spelling of the current directory
    # (".", "./", an absolute path, ...) is detected, not only the literal ".".
    if self._output_dir.resolve() == Path.cwd().resolve():
        warnings.warn(
            "Using the current directory as output_dir is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
    return self._output_dir
|
||||
|
||||
@classmethod
def _from_config(cls, config: JupyterCodeExecutorConfig) -> Self:
    """Create an instance from a config object.

    An empty or missing ``output_dir`` is passed on as ``None``; any other
    value is converted to a ``Path``.
    """
    configured_dir = config.output_dir
    directory = Path(configured_dir) if configured_dir else None
    return cls(kernel_name=config.kernel_name, timeout=config.timeout, output_dir=directory)
|
||||
517
agent_dhal/agentdhal_extensions/code_executors/local/__init__.py
Normal file
517
agent_dhal/agentdhal_extensions/code_executors/local/__init__.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor, FunctionWithRequirements, FunctionWithRequirementsStr
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from .._common import (
|
||||
PYTHON_VARIANTS,
|
||||
CommandLineCodeResult,
|
||||
build_python_functions_file,
|
||||
get_file_name_from_content,
|
||||
lang_to_cmd,
|
||||
silence_pip,
|
||||
to_stub,
|
||||
)
|
||||
|
||||
__all__ = ("LocalCommandLineCodeExecutor",)
|
||||
|
||||
A = ParamSpec("A")
|
||||
|
||||
|
||||
class LocalCommandLineCodeExecutorConfig(BaseModel):
    """Declarative configuration from which a LocalCommandLineCodeExecutor can be rebuilt."""

    # Timeout for the execution of a single code block.
    timeout: int = 60
    # Working directory as a string; None lets the executor use a temporary directory.
    work_dir: Optional[str] = None
    # Name of the module file that holds the user-provided functions.
    functions_module: str = "functions"
    # Whether the code files written for execution are deleted afterwards.
    cleanup_temp_files: bool = True
|
||||
|
||||
|
||||
class LocalCommandLineCodeExecutor(CodeExecutor, Component[LocalCommandLineCodeExecutorConfig]):
|
||||
"""A code executor class that executes code through a local command line
|
||||
environment.
|
||||
|
||||
.. danger::
|
||||
|
||||
This will execute code on the local machine. If being used with LLM generated code, caution should be used.
|
||||
|
||||
Each code block is saved as a file and executed in a separate process in
|
||||
the working directory, and a unique file is generated and saved in the
|
||||
working directory for each code block.
|
||||
The code blocks are executed in the order they are received.
|
||||
Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
|
||||
commands from being executed which may potentially affect the users environment.
|
||||
Currently the only supported languages are Python and shell scripts.
|
||||
For Python code, use the language "python" for the code block.
|
||||
For shell scripts, use the language "bash", "shell", "sh", "pwsh", "powershell", or "ps1" for the code
|
||||
block.
|
||||
|
||||
.. note::
|
||||
|
||||
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
Args:
|
||||
timeout (int): The timeout for the execution of any single code block. Default is 60.
|
||||
work_dir (str): The working directory for the code execution. If None,
|
||||
a default working directory will be used. The default working directory is a temporary directory.
|
||||
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
|
||||
functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions".
|
||||
cleanup_temp_files (bool, optional): Whether to automatically clean up temporary files after execution. Defaults to True.
|
||||
virtual_env_context (Optional[SimpleNamespace], optional): The virtual environment context. Defaults to None.
|
||||
|
||||
.. note::
|
||||
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
How to use `LocalCommandLineCodeExecutor` with a virtual environment different from the one used to run the autogen application:
|
||||
Set up a virtual environment using the `venv` module, and pass its context to the initializer of `LocalCommandLineCodeExecutor`. This way, the executor will run code within the new environment.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import venv
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeBlock
|
||||
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
|
||||
|
||||
|
||||
async def example():
|
||||
work_dir = Path("coding")
|
||||
work_dir.mkdir(exist_ok=True)
|
||||
|
||||
venv_dir = work_dir / ".venv"
|
||||
venv_builder = venv.EnvBuilder(with_pip=True)
|
||||
venv_builder.create(venv_dir)
|
||||
venv_context = venv_builder.ensure_directories(venv_dir)
|
||||
|
||||
local_executor = LocalCommandLineCodeExecutor(work_dir=work_dir, virtual_env_context=venv_context)
|
||||
await local_executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="bash", code="pip install matplotlib"),
|
||||
],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = LocalCommandLineCodeExecutorConfig
|
||||
component_provider_override = "agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor"
|
||||
|
||||
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
|
||||
"bash",
|
||||
"shell",
|
||||
"sh",
|
||||
"pwsh",
|
||||
"powershell",
|
||||
"ps1",
|
||||
"python",
|
||||
]
|
||||
FUNCTION_PROMPT_TEMPLATE: ClassVar[
|
||||
str
|
||||
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
|
||||
|
||||
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
|
||||
|
||||
$functions"""
|
||||
|
||||
def __init__(
    self,
    timeout: int = 60,
    work_dir: Optional[Union[Path, str]] = None,
    functions: Sequence[
        Union[
            FunctionWithRequirements[Any, A],
            Callable[..., Any],
            FunctionWithRequirementsStr,
        ]
    ] = [],
    functions_module: str = "functions",
    cleanup_temp_files: bool = True,
    virtual_env_context: Optional[SimpleNamespace] = None,
):
    """Validate the settings and prepare the working directory.

    Args:
        timeout (int): The timeout for the execution of any single code block.
        work_dir (Optional[Union[Path, str]]): The working directory. It is
            created (including missing parents) when given. When ``None``, a
            temporary directory is created later on demand.
        functions: Functions made available to the executed code. The sequence
            is stored as-is and never modified here.
        functions_module (str): Module name used for the functions file.
        cleanup_temp_files (bool): Whether written code files are removed after execution.
        virtual_env_context (Optional[SimpleNamespace]): Virtual environment context to run code in.

    Raises:
        ValueError: If ``timeout`` is below 1 or ``functions_module`` is not a
            valid Python identifier.
    """
    # Validate every argument before touching the file system, so a rejected
    # call does not leave a freshly created directory behind.
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")
    if not functions_module.isidentifier():
        raise ValueError("Module name must be a valid Python identifier")

    self._timeout = timeout
    self._functions_module = functions_module

    self._work_dir: Optional[Path] = None
    if work_dir is not None:
        # Check if user provided work_dir is the current directory and warn if so.
        if Path(work_dir).resolve() == Path.cwd().resolve():
            warnings.warn(
                "Using the current directory as work_dir is deprecated.",
                DeprecationWarning,
                stacklevel=2,
            )
        self._work_dir = Path(work_dir) if isinstance(work_dir, str) else work_dir
        # parents=True so that a nested working directory can be created in one go.
        self._work_dir.mkdir(exist_ok=True, parents=True)

    self._functions = functions
    # Setup could take some time so we intentionally wait for the first code block to do it.
    self._setup_functions_complete = len(functions) == 0

    self._cleanup_temp_files = cleanup_temp_files
    self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context

    self._temp_dir: Optional[tempfile.TemporaryDirectory[str]] = None
    self._started = False

    # Check the current event loop policy if on windows.
    if sys.platform == "win32":
        current_policy = asyncio.get_event_loop_policy()
        if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance(
            current_policy, asyncio.WindowsProactorEventLoopPolicy
        ):
            warnings.warn(
                "The current event loop policy is not WindowsProactorEventLoopPolicy. "
                "This may cause issues with subprocesses. "
                "Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
                "For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
                "See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.",
                stacklevel=2,
            )
|
||||
|
||||
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
    """(Experimental) Render the available functions into a prompt.

    Two placeholders are substituted in the template:
    - `$module_name`: the name of the functions module.
    - `$functions`: the stub of every function, separated by two newlines.

    Args:
        prompt_template (str): The prompt template. Default is the class default.

    Returns:
        str: The formatted prompt.
    """
    template = Template(prompt_template)
    stubs = [to_stub(func) for func in self._functions]
    return template.substitute(module_name=self._functions_module, functions="\n\n".join(stubs))
|
||||
|
||||
@property
def timeout(self) -> int:
    """(Experimental) The timeout value this executor was configured with."""
    configured_timeout = self._timeout
    return configured_timeout
|
||||
|
||||
@property
def work_dir(self) -> Path:
    """(Experimental) The working directory for the code execution.

    Returns the configured directory when there is one. Otherwise a temporary
    directory is created on first access (which also marks the executor as
    started) and its path is returned on this and every later access.
    """
    configured_dir = self._work_dir
    if configured_dir is not None:
        return configured_dir

    # Automatically create temp directory if not exists
    if self._temp_dir is None:
        self._temp_dir = tempfile.TemporaryDirectory()
        self._started = True
    return Path(self._temp_dir.name)
|
||||
|
||||
@property
def functions(self) -> List[str]:
    """Not implemented; reading this property always raises.

    Raises:
        NotImplementedError: On every access.
    """
    raise NotImplementedError()
|
||||
|
||||
@property
def functions_module(self) -> str:
    """(Experimental) The name of the module that holds the functions."""
    module_name = self._functions_module
    return module_name
|
||||
|
||||
@property
def cleanup_temp_files(self) -> bool:
    """(Experimental) Whether temporary files are cleaned up automatically after execution."""
    should_cleanup = self._cleanup_temp_files
    return should_cleanup
|
||||
|
||||
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
    """Write the functions module, install its requirements and check that it loads.

    Steps:
        1. Generate the functions file and write it into the working directory.
        2. ``pip install`` the packages required by the functions, using the
           virtual environment's interpreter when one is configured.
        3. Execute the generated file once to surface syntax or import errors.

    Args:
        cancellation_token (CancellationToken): Token linked to the subprocess creation task.

    Raises:
        ValueError: If the pip install times out, is cancelled or fails, or if
            the functions file does not execute successfully.
    """
    func_file_content = build_python_functions_file(self._functions)
    func_file = self.work_dir / f"{self._functions_module}.py"
    # Python reads source files as UTF-8 by default, so write the module as
    # UTF-8 instead of whatever the locale's default encoding happens to be.
    func_file.write_text(func_file_content, encoding="utf-8")

    # Collect requirements
    lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
    flattened_packages = [item for sublist in lists_of_packages for item in sublist]
    required_packages = list(set(flattened_packages))
    if len(required_packages) > 0:
        logging.info("Ensuring packages are installed in executor.")

        cmd_args = ["-m", "pip", "install", *required_packages]

        if self._virtual_env_context:
            py_executable = self._virtual_env_context.env_exe
        else:
            py_executable = sys.executable

        async def _terminate(process: Any) -> None:
            """Stop a pip process that is still running and wait for it to exit."""
            if process is None or process.returncode is not None:
                return
            try:
                process.terminate()
            except ProcessLookupError:
                # The process exited on its own in the meantime.
                return
            await process.wait()

        task = asyncio.create_task(
            asyncio.create_subprocess_exec(
                py_executable,
                *cmd_args,
                cwd=self.work_dir,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        )
        cancellation_token.link_future(task)
        proc = None  # Stays None if the subprocess was never created
        try:
            proc = await task
            stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
        except asyncio.TimeoutError as e:
            # Do not leave the installer running in the background.
            await _terminate(proc)
            raise ValueError("Pip install timed out") from e
        except asyncio.CancelledError as e:
            await _terminate(proc)
            raise ValueError("Pip install was cancelled") from e

        if proc.returncode is not None and proc.returncode != 0:
            raise ValueError(f"Pip install failed. {stdout.decode()}, {stderr.decode()}")

    # Attempt to load the function file to check for syntax errors, imports etc.
    exec_result = await self._execute_code_dont_check_setup(
        [CodeBlock(code=func_file_content, language="python")], cancellation_token
    )

    if exec_result.exit_code != 0:
        raise ValueError(f"Functions failed to load: {exec_result.output}")

    self._setup_functions_complete = True
|
||||
|
||||
async def execute_code_blocks(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
    """(Experimental) Execute the code blocks and return the result.

    The functions setup is run first if it has not completed yet.

    Args:
        code_blocks (List[CodeBlock]): The code blocks to execute.
        cancellation_token (CancellationToken): a token to cancel the operation

    Returns:
        CommandLineCodeResult: The result of the code execution.
    """
    setup_pending = not self._setup_functions_complete
    if setup_pending:
        await self._setup_functions(cancellation_token)

    outcome = await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
    return outcome
|
||||
|
||||
async def _execute_code_dont_check_setup(
    self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
) -> CommandLineCodeResult:
    """Execute the provided code blocks in the local command line without re-checking setup.

    Each block is written to a file in the working directory and run in its
    own subprocess. Processing stops at the first block that uses an
    unsupported language, names a file outside the workspace, times out
    (exit code 124), is cancelled (exit code 125) or exits with a non-zero
    code.

    When ``cleanup_temp_files`` is enabled, every file written by this call is
    removed before returning. This also covers the early return for an
    invalid file name and the case where an unexpected exception propagates.

    Args:
        code_blocks (List[CodeBlock]): The code blocks to execute.
        cancellation_token (CancellationToken): Token linked to each subprocess creation task.

    Returns:
        CommandLineCodeResult: Exit code, accumulated stderr/stdout logs and
        the path of the first written code file (or ``None``).
    """
    logs_all: str = ""
    file_names: List[Path] = []
    exitcode = 0

    try:
        for code_block in code_blocks:
            lang, code = code_block.language, code_block.code
            lang = lang.lower()

            # Remove pip output where possible
            code = silence_pip(code, lang)

            # Normalize python variants to "python"
            if lang in PYTHON_VARIANTS:
                lang = "python"

            # Abort if not supported
            if lang not in self.SUPPORTED_LANGUAGES:
                exitcode = 1
                logs_all += "\n" + f"unknown language {lang}"
                break

            # Try extracting a filename (if present)
            try:
                filename = get_file_name_from_content(code, self.work_dir)
            except ValueError:
                return CommandLineCodeResult(
                    exit_code=1,
                    output="Filename is not in the workspace",
                    code_file=None,
                )

            # If no filename is found, create one
            if filename is None:
                code_hash = sha256(code.encode()).hexdigest()
                if lang.startswith("python"):
                    ext = "py"
                elif lang in ["pwsh", "powershell", "ps1"]:
                    ext = "ps1"
                else:
                    ext = lang

                filename = f"tmp_code_{code_hash}.{ext}"

            written_file = (self.work_dir / filename).resolve()
            with written_file.open("w", encoding="utf-8") as f:
                f.write(code)
            file_names.append(written_file)

            # Build environment
            env = os.environ.copy()
            if self._virtual_env_context:
                virtual_env_bin_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
                # PATH may be unset; in that case use the venv bin directory alone
                # rather than failing with KeyError or adding an empty PATH entry.
                inherited_path = env.get("PATH")
                if inherited_path:
                    env["PATH"] = f"{virtual_env_bin_abs_path}{os.pathsep}{inherited_path}"
                else:
                    env["PATH"] = virtual_env_bin_abs_path

            # Decide how to invoke the script
            if lang == "python":
                program = (
                    os.path.abspath(self._virtual_env_context.env_exe) if self._virtual_env_context else sys.executable
                )
                extra_args = [str(written_file.absolute())]
            else:
                # Get the appropriate command for the language
                program = lang_to_cmd(lang)

                # Special handling for PowerShell
                if program == "pwsh":
                    extra_args = [
                        "-NoProfile",
                        "-ExecutionPolicy",
                        "Bypass",
                        "-File",
                        str(written_file.absolute()),
                    ]
                else:
                    # Shell commands (bash, sh, etc.)
                    extra_args = [str(written_file.absolute())]

            # Create a subprocess and run
            task = asyncio.create_task(
                asyncio.create_subprocess_exec(
                    program,
                    *extra_args,
                    cwd=self.work_dir,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    env=env,
                )
            )
            cancellation_token.link_future(task)

            proc = None  # Track the process
            try:
                proc = await task
                stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
                exitcode = proc.returncode or 0
            except asyncio.TimeoutError:
                logs_all += "\nTimeout"
                exitcode = 124
                if proc:
                    proc.terminate()
                    await proc.wait()  # Ensure process is fully dead
                break
            except asyncio.CancelledError:
                logs_all += "\nCancelled"
                exitcode = 125
                if proc:
                    proc.terminate()
                    await proc.wait()
                break

            logs_all += stderr.decode()
            logs_all += stdout.decode()

            if exitcode != 0:
                break

        code_file = str(file_names[0]) if file_names else None
        return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)
    finally:
        # Runs on every exit path (normal return, early return, exception) so
        # written code files never outlive the call when cleanup is enabled.
        if self._cleanup_temp_files:
            for file in file_names:
                try:
                    file.unlink(missing_ok=True)
                except OSError as error:
                    logging.error(f"Failed to delete temporary file {file}: {error}")
|
||||
|
||||
async def restart(self) -> None:
    """(Experimental) Restart the code executor.

    Restarting is not supported for this executor: a warning is emitted and
    nothing else happens.
    """
    notice = "Restarting local command line code executor is not supported. No action is taken."
    warnings.warn(notice, stacklevel=2)
|
||||
|
||||
async def start(self) -> None:
    """(Experimental) Start the code executor.

    Marks the executor as started. When neither a working directory was
    provided nor a temporary directory exists yet, a temporary directory is
    created for the executor to use. Should be called before executing any
    code blocks.
    """
    needs_temp_dir = self._work_dir is None and self._temp_dir is None
    if needs_temp_dir:
        self._temp_dir = tempfile.TemporaryDirectory()
    self._started = True
|
||||
|
||||
async def stop(self) -> None:
    """(Experimental) Stop the code executor.

    Cleans up the temporary working directory (if one was created) and marks
    the executor's internal state as no longer started.
    """
    temp_dir = self._temp_dir
    if temp_dir is not None:
        temp_dir.cleanup()
        self._temp_dir = None
    self._started = False
|
||||
|
||||
def _to_config(self) -> LocalCommandLineCodeExecutorConfig:
    """Build the config object describing this instance.

    Functions and the virtual environment context are not part of the config;
    an info message is logged for each of them that is set.
    """
    omitted = (
        (self._functions, "Functions will not be included in serialized configuration"),
        (self._virtual_env_context, "Virtual environment context will not be included in serialized configuration"),
    )
    for value, notice in omitted:
        if value:
            logging.info(notice)

    settings = {
        "timeout": self._timeout,
        "work_dir": str(self.work_dir),
        "functions_module": self._functions_module,
        "cleanup_temp_files": self._cleanup_temp_files,
    }
    return LocalCommandLineCodeExecutorConfig(**settings)
|
||||
|
||||
@classmethod
def _from_config(cls, config: LocalCommandLineCodeExecutorConfig) -> Self:
    """Create an instance from a config object.

    A ``work_dir`` of ``None`` is passed through; any other value is
    converted to a ``Path``.
    """
    configured_dir = config.work_dir
    directory = None if configured_dir is None else Path(configured_dir)
    return cls(
        timeout=config.timeout,
        work_dir=directory,
        functions_module=config.functions_module,
        cleanup_temp_files=config.cleanup_temp_files,
    )
|
||||
@@ -0,0 +1,210 @@
|
||||
# Task-Centric Memory
|
||||
_(EXPERIMENTAL, RESEARCH IN PROGRESS)_
|
||||
|
||||
**Task-Centric Memory** is an active research project aimed at giving AI agents the ability to:
|
||||
|
||||
* Accomplish general tasks more effectively by learning quickly and continually beyond context-window limitations.
|
||||
* Remember guidance, corrections, plans, and demonstrations provided by users.
|
||||
* Learn through the agent's own experience and adapt quickly to changing circumstances.
|
||||
* Avoid repeating mistakes on tasks that are similar to those previously encountered.
|
||||
|
||||
## Installation
|
||||
|
||||
Install AutoGen and its extension package as follows:
|
||||
|
||||
```bash
|
||||
pip install -U "autogen-agentchat" "autogen-ext[openai]" "autogen-ext[task-centric-memory]"
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory_2.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
This first code snippet runs a basic test to verify that the installation was successful,
|
||||
as illustrated by the diagram to the right.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.experimental.task_centric_memory import MemoryController
|
||||
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful.
|
||||
memory_controller = MemoryController(reset=True, client=client, logger=logger)
|
||||
|
||||
# Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
|
||||
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
|
||||
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
|
||||
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
|
||||
|
||||
# Retrieve memories for a new task that's related to only two of the stored memories.
|
||||
memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
|
||||
print("{} memories retrieved".format(len(memos)))
|
||||
for memo in memos:
|
||||
print("- " + memo.insight)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory_3.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
This second code example shows one way to incorporate task-centric memory directly into an AutoGen agent,
|
||||
in this case a subclass of RoutedAgent.
|
||||
To keep the code short, only the simplest form of memory retrieval is exercised by this agent.
|
||||
|
||||
```python
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler
|
||||
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.experimental.task_centric_memory import MemoryController
|
||||
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
content: str
|
||||
|
||||
|
||||
class MemoryEnabledAgent(RoutedAgent):
|
||||
def __init__(
|
||||
self, description: str, model_client: ChatCompletionClient, memory_controller: MemoryController
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._model_client = model_client
|
||||
self._memory_controller = memory_controller
|
||||
|
||||
@message_handler
|
||||
async def handle_message(self, message: Message, context: MessageContext) -> Message:
|
||||
# Retrieve relevant memories for the task.
|
||||
memos = await self._memory_controller.retrieve_relevant_memos(task=message.content)
|
||||
|
||||
# Format the memories for the model.
|
||||
formatted_memos = "Info that may be useful:\n" + "\n".join(["- " + memo.insight for memo in memos])
|
||||
print(f"{'-' * 23}Text appended to the user message{'-' * 24}\n{formatted_memos}\n{'-' * 80}")
|
||||
|
||||
# Create the messages for the model with the retrieved memories.
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content=message.content, source="user"),
|
||||
UserMessage(content=formatted_memos, source="user"),
|
||||
]
|
||||
|
||||
# Call the model with the messages.
|
||||
model_result = await self._model_client.create(messages=messages)
|
||||
assert isinstance(model_result.content, str)
|
||||
|
||||
# Send the model's response to the user.
|
||||
return Message(content=model_result.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart2"}) # Optional, but very useful.
|
||||
memory_controller = MemoryController(reset=True, client=client, logger=logger)
|
||||
|
||||
# Prepopulate memory to mimic learning from a prior session.
|
||||
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
|
||||
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
|
||||
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
|
||||
|
||||
# Create and start an agent runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
|
||||
# Register the agent type.
|
||||
await MemoryEnabledAgent.register(
|
||||
runtime,
|
||||
"memory_enabled_agent",
|
||||
lambda: MemoryEnabledAgent(
|
||||
"An agent with memory", model_client=client, memory_controller=memory_controller
|
||||
),
|
||||
)
|
||||
|
||||
# Send a direct message to the agent.
|
||||
request = "What colors do I like most?"
|
||||
print("User request: " + request)
|
||||
response = await runtime.send_message(
|
||||
Message(content=request), AgentId("memory_enabled_agent", "default")
|
||||
)
|
||||
print("Agent response: " + response.content)
|
||||
|
||||
# Stop the agent runtime.
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Sample Code
|
||||
|
||||
The example above modifies the agent's code.
|
||||
But it's also possible to add task-centric memory to an agent or multi-agent team _without_ modifying any agent code.
|
||||
See the [sample code](../../../../../../samples/task_centric_memory) for that and other forms of fast, memory-based learning.
|
||||
|
||||
|
||||
## Architecture
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory.png" alt="Description" width="300" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
The block diagram to the right outlines the key components of the architecture in the most general form.
|
||||
The memory components are shown in blue, and the green blocks represent external components.
|
||||
|
||||
The **Memory Controller** implements the fast-learning methods described below,
|
||||
and manages communication with a **Memory Bank** containing a vector DB and associated structures.
|
||||
|
||||
The **Agent or Team** is the AI agent or team of agents to which memory is being added.
|
||||
The sample code shows how to add task-centric memory to a simple AssistantAgent or a MagenticOneGroupChat team.
|
||||
|
||||
The **Apprentice, app, or service** represents the code that instantiates the agent and memory controller,
|
||||
and routes information between them, effectively wrapping agent and memory into a combined component.
|
||||
The term _Apprentice_ connotes that this combination uses memory to learn quickly on the job.
|
||||
The Apprentice class is a minimal reference implementation provided as utility code for illustration and testing,
|
||||
but most applications will use their own code instead of the Apprentice.
|
||||
|
||||
## Memory Creation and Storage
|
||||
|
||||
Each stored memory (called a _memo_) contains a text insight and (optionally) a task description.
|
||||
The insight is intended to help the agent accomplish future tasks that are similar to a prior task.
|
||||
The memory controller provides methods for different types of learning.
|
||||
If the user provides advice for solving a given task, the advice is extracted by the model client and stored as an insight.
|
||||
If the user demonstrates how to perform a task,
|
||||
the task and demonstration are stored together as an insight used to solve similar but different tasks.
|
||||
If the agent is given a task (free of side-effects) and some means of determining success or failure,
|
||||
the memory controller repeats the following learning loop in the background some number of times:
|
||||
|
||||
1. Test the agent on the task a few times to check for a failure.
|
||||
2. If a failure is found, analyze the agent's response in order to:
|
||||
1. Diagnose the failure of reasoning or missing information,
|
||||
2. Phrase a general piece of advice, such as what a teacher might give to a student,
|
||||
3. Temporarily append this advice to the task description,
|
||||
4. Return to step 1.
|
||||
5. If some piece of advice succeeds in helping the agent solve the task a number of times, add the advice as an insight to memory.
|
||||
3. For each insight to be stored in memory, an LLM is prompted to generate a set of free-form, multi-word topics related to the insight. Each topic is embedded to a fixed-length vector and stored in a vector DB mapping it to the topic’s related insight.
|
||||
|
||||
## Memory Retrieval and Usage
|
||||
|
||||
The memory controller provides methods for different types of memory retrieval.
|
||||
When the agent is given a task, the following steps are performed by the controller:
|
||||
1. The task is rephrased into a generalized form.
|
||||
2. A set of free-form, multi-word query topics is generated from the generalized task.
|
||||
3. A potentially large number of previously stored topics, those most similar to each query topic, are retrieved from the vector DB along with the insights they map to.
|
||||
4. These candidate memos are filtered by the aggregate similarity of their stored topics to the query topics.
|
||||
5. In the final filtering stage, an LLM is prompted to validate only those insights that seem potentially useful in solving the task at hand.
|
||||
|
||||
Retrieved insights that pass the filtering steps are listed under a heading like
|
||||
"Important insights that may help solve tasks like this", then appended to the task description before it is passed to the agent as usual.
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._memory_bank import MemoryBankConfig
|
||||
from .memory_controller import MemoryController, MemoryControllerConfig
|
||||
|
||||
__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"]
|
||||
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict
|
||||
|
||||
from ._string_similarity_map import StringSimilarityMap
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
@dataclass
class Memo:
    """
    A single unit of memory: an insight, optionally paired with the task it relates to.

    Memos are what a memory bank stores and later hands back on retrieval.
    """

    # Description of the task the insight relates to; None when there is no task.
    task: str | None
    # A hint, solution, plan, or any other text that may help solve a similar task.
    insight: str
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class MemoryBankConfig(TypedDict, total=False):
    """
    Optional overrides for a memory bank's settings.

    Every key may be omitted (``total=False``); a missing key leaves the
    corresponding default of the memory bank in place.
    """

    # Directory where the memory bank files are stored.
    path: str
    # Threshold used to normalize relevance.
    relevance_conversion_threshold: float
    # Maximum number of most relevant results to return for any given topic.
    n_results: int
    # Maximum string-pair distance for a memo to be retrieved.
    distance_threshold: int
|
||||
|
||||
|
||||
class MemoryBank:
    """
    Stores task-completion insights as memories in a vector DB for later retrieval.

    Each memo is kept in ``uid_memo_dict`` under a string ID, and every topic of the
    memo is mapped to that ID through a ``StringSimilarityMap``.

    Args:
        reset: True to clear the DB before starting.
        config: An optional dict that can be used to override the following values:

            - path: The path to the directory where the memory bank files are stored.
            - relevance_conversion_threshold: The threshold used to normalize relevance.
            - n_results: The maximum number of most relevant results to return for any given topic.
            - distance_threshold: The maximum string-pair distance for a memo to be retrieved.

        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(
        self,
        reset: bool,
        config: MemoryBankConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.logger.enter_function()

        # Apply default settings and any config overrides.
        memory_dir_path = "./memory_bank/default"
        self.relevance_conversion_threshold = 1.7
        self.n_results = 25
        self.distance_threshold = 100
        if config is not None:
            memory_dir_path = config.get("path", memory_dir_path)
            self.relevance_conversion_threshold = config.get(
                "relevance_conversion_threshold", self.relevance_conversion_threshold
            )
            self.n_results = config.get("n_results", self.n_results)
            self.distance_threshold = config.get("distance_threshold", self.distance_threshold)

        memory_dir_path = os.path.expanduser(memory_dir_path)
        self.logger.info("\nMEMORY BANK DIRECTORY {}".format(memory_dir_path))
        path_to_db_dir = os.path.join(memory_dir_path, "string_map")
        self.path_to_dict = os.path.join(memory_dir_path, "uid_memo_dict.pkl")

        self.string_map = StringSimilarityMap(reset=reset, path_to_db_dir=path_to_db_dir, logger=self.logger)

        # Load or create the associated memo dict on disk.
        self.uid_memo_dict: Dict[str, Memo] = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict))
            # NOTE(security): unpickling can execute arbitrary code. Only point the
            # memory bank at directories whose contents are trusted.
            with open(self.path_to_dict, "rb") as f:
                self.uid_memo_dict = pickle.load(f)
                # The ID counter is not persisted; it is re-derived from the memo count.
                self.last_memo_id = len(self.uid_memo_dict)
                self.logger.info("\n{} MEMOS LOADED".format(len(self.uid_memo_dict)))

        # Clear the DB if requested.
        if reset:
            self._reset_memos()

        self.logger.leave_function()

    def reset(self) -> None:
        """
        Forces immediate deletion of all contents, in memory and on disk.
        """
        self.string_map.reset_db()
        self._reset_memos()

    def _reset_memos(self) -> None:
        """
        Forces immediate deletion of the memos, in memory and on disk.
        """
        self.logger.info("\nCLEARING MEMOS")
        self.uid_memo_dict = {}
        # Restart the ID counter together with the dict. On reload the counter is
        # re-derived as len(uid_memo_dict), so letting it keep growing across a reset
        # would leave stored IDs larger than that length, and later memos would then
        # be assigned an ID that is already in use.
        self.last_memo_id = 0
        self.save_memos()

    def save_memos(self) -> None:
        """
        Saves the current memo structures (possibly empty) to disk.
        """
        self.string_map.save_string_pairs()
        with open(self.path_to_dict, "wb") as file:
            self.logger.info("\nSAVING MEMOS TO DISK at {}".format(self.path_to_dict))
            pickle.dump(self.uid_memo_dict, file)

    def contains_memos(self) -> bool:
        """
        Returns True if the memory bank contains any memo.
        """
        return len(self.uid_memo_dict) > 0

    def _map_topics_to_memo(self, topics: List[str], memo_id: str, memo: Memo) -> None:
        """
        Adds a mapping in the vec DB from each topic to the memo, stores the memo
        under ``memo_id``, and saves everything to disk.
        """
        self.logger.enter_function()
        self.logger.info("\nINSIGHT\n{}".format(memo.insight))
        for topic in topics:
            self.logger.info("\n TOPIC = {}".format(topic))
            self.string_map.add_input_output_pair(topic, memo_id)
        self.uid_memo_dict[memo_id] = memo
        self.save_memos()
        self.logger.leave_function()

    def add_memo(self, insight_str: str, topics: List[str], task_str: Optional[str] = None) -> None:
        """
        Adds an insight to the memory bank, given topics related to the insight, and optionally the task.
        """
        self.logger.enter_function()
        self.last_memo_id += 1
        id_str = str(self.last_memo_id)
        insight = Memo(insight=insight_str, task=task_str)
        self._map_topics_to_memo(topics, id_str, insight)
        self.logger.leave_function()

    def add_task_with_solution(self, task: str, solution: str, topics: List[str]) -> None:
        """
        Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
        This is useful when the insight is a demonstration of how to solve a given type of task.
        """
        self.logger.enter_function()
        self.last_memo_id += 1
        id_str = str(self.last_memo_id)
        # Combine the task and its solution into a single insight, so that the
        # demonstration is retrieved together with the task it solves.
        insight_str = "Example task:\n\n{}\n\nExample solution:\n\n{}".format(task, solution)
        memo = Memo(insight=insight_str, task=task)
        self._map_topics_to_memo(topics, id_str, memo)
        self.logger.leave_function()

    def get_relevant_memos(self, topics: List[str]) -> List[Memo]:
        """
        Returns any memos from the memory bank that appear sufficiently relevant to the input topics.

        A match contributes ``relevance_conversion_threshold - distance`` to its memo;
        contributions are summed per memo, and memos whose total is non-negative are
        returned in order of decreasing total relevance.
        """
        self.logger.enter_function()

        # Retrieve all topic matches, and gather them into a single list.
        matches: List[Tuple[str, str, float]] = []  # Each match is a tuple: (topic, memo_id, distance)
        for topic in topics:
            matches.extend(self.string_map.get_related_string_pairs(topic, self.n_results, self.distance_threshold))

        # Build a dict of memo-relevance pairs from the matches.
        memo_relevance_dict: Dict[str, float] = {}
        for _, memo_id, distance in matches:
            relevance = self.relevance_conversion_threshold - distance
            memo_relevance_dict[memo_id] = memo_relevance_dict.get(memo_id, 0.0) + relevance

        # Log the details of all the retrieved memos.
        self.logger.info("\n{} POTENTIALLY RELEVANT MEMOS".format(len(memo_relevance_dict)))
        for memo_id, relevance in memo_relevance_dict.items():
            memo = self.uid_memo_dict[memo_id]
            details = ""
            if memo.task is not None:
                details += "\n  TASK: {}\n".format(memo.task)
            details += "\n  INSIGHT: {}\n\n  RELEVANCE: {:.3f}\n".format(memo.insight, relevance)
            self.logger.info(details)

        # Sort the memo-relevance pairs by relevance, in descending order.
        ranked = sorted(memo_relevance_dict.items(), key=lambda item: item[1], reverse=True)

        # Compose the list of sufficiently relevant memos to return.
        memo_list: List[Memo] = [self.uid_memo_dict[memo_id] for memo_id, relevance in ranked if relevance >= 0]

        self.logger.leave_function()
        return memo_list
|
||||
@@ -0,0 +1,289 @@
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from .utils._functions import UserContent
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
class Prompter:
    """
    Centralizes most of the Apprentice prompts sent to the model client.

    The prompter keeps a chat history so that multi-step prompts (several
    ``call_model`` calls in a row) share context. Each public prompt method
    clears that history before its first call.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger

        self.client = client
        self.default_system_message_content = "You are a helpful assistant."
        self.time_spent_in_model_calls = 0.0
        self.num_model_calls = 0
        self.start_time = time.time()

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: A short description of the call, passed to the logger.
            user_content: The content of the new user message.
            system_message_content: The system prompt. If None, the default system message is used.
            keep_these_messages: If True, the user message and the model's response
                are appended to the chat history after the call.

        Returns:
            The model's response as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = self.default_system_message_content
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)

        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Double check the types of the input messages.
        for message in input_messages:
            content = message.content
            # A plain string is a single part; don't walk it character by character.
            parts = [content] if isinstance(content, str) else content
            for part in parts:
                assert isinstance(part, (str, Image)), "Invalid message content type: {}".format(type(part))

        # Call the model
        start_time = time.time()
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")
        assert isinstance(response_message, AssistantMessage)
        self.time_spent_in_model_calls += time.time() - start_time
        self.num_model_calls += 1

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string for now
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def learn_from_failure(
        self, task_description: str, memory_section: str, final_response: str, expected_answer: str, work_history: str
    ) -> str:
        """
        Tries to create an insight to help avoid the given failure in the future.

        The model is prompted three times in one conversation: to review the work,
        to state the misconception, and to phrase concise advice. Only the answer
        to the last prompt is returned.
        """
        sys_message = """- You are a patient and thorough teacher.
- Your job is to review work done by students and help them learn how to do better."""

        user_message: List[Union[str, Image]] = []
        user_message.append("# A team of students made a mistake on the following task:\n")
        user_message.append(task_description)

        if len(memory_section) > 0:
            user_message.append(memory_section)

        user_message.append("# Here's the expected answer, which would have been correct:\n")
        user_message.append(expected_answer)

        user_message.append("# Here is the students' answer, which was INCORRECT:\n")
        user_message.append(final_response)

        user_message.append("# Please review the students' work which follows:\n")
        user_message.append("**-----  START OF STUDENTS' WORK  -----**\n\n")
        user_message.append(work_history)
        user_message.append("\n**-----  END OF STUDENTS' WORK  -----**\n\n")

        user_message.append(
            "# Now carefully review the students' work above, explaining in detail what the students did right and what they did wrong.\n"
        )

        self._clear_history()
        await self.call_model(
            summary="Ask the model to learn from this failure",
            system_message_content=sys_message,
            user_content=user_message,
        )
        user_message = [
            "Now put yourself in the mind of the students. What misconception led them to their incorrect answer?"
        ]
        await self.call_model(
            summary="Ask the model to state the misconception",
            system_message_content=sys_message,
            user_content=user_message,
        )

        user_message = [
            "Please express your key insights in the form of short, general advice that will be given to the students. Just one or two sentences, or they won't bother to read it."
        ]
        insight = await self.call_model(
            summary="Ask the model to formulate a concise insight",
            system_message_content=sys_message,
            user_content=user_message,
        )
        return insight

    async def find_index_topics(self, input_string: str) -> List[str]:
        """
        Returns a list of topics related to the given string.

        The model is asked to list one topic per line. Surrounding whitespace is
        removed from each line and blank lines are dropped.
        """
        sys_message = """You are an expert at semantic analysis."""

        user_message: List[Union[str, Image]] = []
        user_message.append("""- My job is to create a thorough index for a book called Task Completion, and I need your help.
- Every paragraph in the book needs to be indexed by all the topics related to various kinds of tasks and strategies for completing them.
- Your job is to read the text below and extract the task-completion topics that are covered.
- The number of topics depends on the length and content of the text. But you should list at least one topic, and potentially many more.
- Each topic you list should be a meaningful phrase composed of a few words. Don't use whole sentences as topics.
- Don't include details that are unrelated to the general nature of the task, or a potential strategy for completing tasks.
- List each topic on a separate line, without any extra text like numbering, or bullets, or any other formatting, because we don't want those things in the index of the book.\n\n""")

        user_message.append("# Text to be indexed\n")
        user_message.append(input_string)

        self._clear_history()
        topics = await self.call_model(
            summary="Ask the model to extract topics", system_message_content=sys_message, user_content=user_message
        )

        # Parse the topics into a list, ignoring blank or whitespace-only lines.
        stripped_lines = (line.strip() for line in topics.splitlines())
        topic_list: List[str] = [line for line in stripped_lines if line]

        return topic_list

    async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
        """
        Attempts to rewrite a task description in a more general form.

        Args:
            task_description: The task to be rephrased.
            revise: If truthy, the model is additionally asked to identify irrelevant
                points and to produce a revised list, which is then returned instead
                of the first draft.
        """

        sys_message = """You are a helpful and thoughtful assistant."""

        user_message: List[Union[str, Image]] = [
            "We have been given a task description. Our job is not to complete the task, but merely rephrase the task in simpler, more general terms, if possible. Please read through the following task description, then explain your understanding of the task in detail, as a single, flat list of all the important points."
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)

        self._clear_history()
        generalized_task = await self.call_model(
            summary="Ask the model to rephrase the task in a list of important points",
            system_message_content=sys_message,
            user_content=user_message,
        )

        if revise:
            user_message = [
                "Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
            ]
            await self.call_model(
                summary="Ask the model to identify irrelevant points",
                system_message_content=sys_message,
                user_content=user_message,
            )

            user_message = [
                "Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
            ]
            generalized_task = await self.call_model(
                summary="Ask the model to make a final list of general terms",
                system_message_content=sys_message,
                user_content=user_message,
            )

        return generalized_task

    async def validate_insight(self, insight: str, task_description: str) -> bool:
        """
        Judges whether the insight could help solve the task.

        Returns True only if the model's reply, ignoring surrounding whitespace, is "1".
        """

        sys_message = """You are a helpful and thoughtful assistant."""

        user_message: List[Union[str, Image]] = [
            """We have been given a potential insight that may or may not be useful for solving a given task.
- First review the following task.
- Then review the insight that follows, and consider whether it might help solve the given task.
- Do not attempt to actually solve the task.
- Reply with a single character, '1' if the insight may be useful, or '0' if it is not."""
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Possibly useful insight")
        user_message.append(insight)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to validate the insight",
            system_message_content=sys_message,
            user_content=user_message,
        )
        # Models often add a trailing newline or space; don't let that flip the verdict.
        return response.strip() == "1"

    async def extract_task(self, text: str) -> str | None:
        """
        Returns a task found in the given text, or None if not found.

        "Not found" means the model's reply, ignoring surrounding whitespace, is "None".
        """
        sys_message = """You are a helpful and thoughtful assistant."""
        user_message: List[Union[str, Image]] = [
            """Does the following text contain a question or some task we are being asked to perform?
- If so, please reply with the full question or task description, along with any supporting information, but without adding extra commentary or formatting.
- If the task is just to remember something, that doesn't count as a task, so don't include it.
- If there is no question or task in the text, simply write "None" with no punctuation."""
        ]
        user_message.append("\n# Text to analyze")
        user_message.append(text)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to extract a task", system_message_content=sys_message, user_content=user_message
        )
        # Compare on the stripped reply so that "None\n" is not mistaken for a task.
        return None if response.strip() == "None" else response

    async def extract_advice(self, text: str) -> str | None:
        """
        Returns advice from the given text, or None if not found.

        "Not found" means the model's reply, ignoring surrounding whitespace, is "None".
        """
        sys_message = """You are a helpful and thoughtful assistant."""
        user_message: List[Union[str, Image]] = [
            """Does the following text contain any information or advice that might be useful later?
- If so, please copy the information or advice, adding no extra commentary or formatting.
- If there is no potentially useful information or advice at all, simply write "None" with no punctuation."""
        ]
        user_message.append("\n# Text to analyze")
        user_message.append(text)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to extract advice", system_message_content=sys_message, user_content=user_message
        )
        # Compare on the stripped reply so that "None\n" is not mistaken for advice.
        return None if response.strip() == "None" else response
|
||||
@@ -0,0 +1,124 @@
|
||||
import os
|
||||
import pickle
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import (
|
||||
QueryResult,
|
||||
)
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
class StringSimilarityMap:
    """
    Provides storage and similarity-based retrieval of string pairs using a vector database.
    Each DB entry is a pair of strings: an input string and an output string.
    The input string is embedded and used as the retrieval key.
    The output string can be anything, but it's typically used as a dict key.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.

    The vector collection only stores the input strings; the (input, output) pairs themselves
    live in ``uid_text_dict``, which is pickled next to the DB by ``save_string_pairs()``.
    Adding a pair does not write the dict to disk; callers must call ``save_string_pairs()``.

    Args:
        - reset: True to clear the DB immediately after creation.
        - path_to_db_dir: Path to the directory where the DB is stored.
        - logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        chromadb_settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(chromadb_settings)
        self.vec_db = self.db_client.create_collection("string-pairs", get_or_create=True)  # The collection is the DB.

        # Load or create the associated string-pair dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict: Dict[str, Tuple[str, str]] = {}
        self.last_string_pair_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict))
            # NOTE(security): pickle.load can execute arbitrary code. Only use DB directories
            # whose contents were written by this class or another trusted source.
            with open(self.path_to_dict, "rb") as f:
                self.uid_text_dict = pickle.load(f)
            # Continue numbering after the highest stored id rather than the entry count, so
            # that ids stay unique even if the stored ids are not exactly 1..n.
            self.last_string_pair_id = max((int(uid) for uid in self.uid_text_dict), default=0)
            if len(self.uid_text_dict) > 0:
                self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict)))
                self._log_string_pairs()

        # Clear the DB if requested.
        if reset:
            self.reset_db()

    def _log_string_pairs(self) -> None:
        """
        Logs all string pairs currently in the map.
        """
        self.logger.debug("LIST OF STRING PAIRS")
        for uid, text in self.uid_text_dict.items():
            input_text, output_text = text
            self.logger.debug(" ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text))

    def save_string_pairs(self) -> None:
        """
        Saves the string-pair dict (self.uid_text_dict) to disk.
        """
        self.logger.debug("\nSAVING STRING SIMILARITY MAP TO DISK at {}".format(self.path_to_dict))
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self) -> None:
        """
        Forces immediate deletion of the DB's contents, in memory and on disk.
        """
        self.logger.debug("\nCLEARING STRING-PAIR MAP")
        self.db_client.delete_collection("string-pairs")
        self.vec_db = self.db_client.create_collection("string-pairs")
        self.uid_text_dict = {}
        # Restart id numbering together with the data; otherwise ids issued after a reset are
        # offset from the entry count and can collide once the map is reloaded from disk.
        self.last_string_pair_id = 0
        self.save_string_pairs()

    def add_input_output_pair(self, input_text: str, output_text: str) -> None:
        """
        Adds one input-output string pair to the DB.

        The pair is stored in memory and in the vector collection only; call
        save_string_pairs() to persist the string-pair dict.
        """
        self.last_string_pair_id += 1
        uid = str(self.last_string_pair_id)
        self.vec_db.add(documents=[input_text], ids=[uid])
        self.uid_text_dict[uid] = input_text, output_text
        self.logger.debug(
            "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format(
                self.last_string_pair_id, input_text, output_text
            )
        )
        # self._log_string_pairs()  # For deeper debugging, uncomment to log all string pairs after each addition.

    def get_related_string_pairs(
        self, query_text: str, n_results: int, threshold: Union[int, float]
    ) -> List[Tuple[str, str, float]]:
        """
        Retrieves up to n string pairs that are related to the given query text within the specified distance threshold.

        Returns a list of (input_text, output_text, distance) tuples, in the order returned by
        the vector DB, keeping only results whose distance is strictly below the threshold.
        """
        string_pairs_with_distances: List[Tuple[str, str, float]] = []
        # Never request more results than there are stored pairs.
        n_results = min(n_results, len(self.uid_text_dict))
        if n_results > 0:
            results: QueryResult = self.vec_db.query(query_texts=[query_text], n_results=n_results)
            num_results = len(results["ids"][0])
            for i in range(num_results):
                uid = results["ids"][0][i]
                input_text = results["documents"][0][i] if results["documents"] else ""
                distance = results["distances"][0][i] if results["distances"] else 0.0
                if distance < threshold:
                    input_text_2, output_text = self.uid_text_dict[uid]
                    # Sanity check that the vector DB and the pickled dict agree on this id.
                    assert input_text == input_text_2
                    self.logger.debug(
                        "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
                            input_text, output_text, distance
                        )
                    )
                    string_pairs_with_distances.append((input_text, output_text, distance))
        return string_pairs_with_distances
|
||||
@@ -0,0 +1,478 @@
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Tuple, TypedDict
|
||||
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
)
|
||||
|
||||
from ._memory_bank import Memo, MemoryBank
|
||||
from ._prompter import Prompter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._memory_bank import MemoryBankConfig
|
||||
from .utils.grader import Grader
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class MemoryControllerConfig(TypedDict, total=False):
    # Optional overrides for MemoryController. Every key may be omitted (total=False);
    # MemoryController.__init__ falls back to its own default for any missing key.
    generalize_task: bool  # Whether to rewrite tasks in more general terms.
    revise_generalized_task: bool  # Whether to critique then rewrite the generalized task.
    generate_topics: bool  # Whether to index/retrieve on topics extracted from the text rather than the text itself.
    validate_memos: bool  # Whether to apply a final validation stage to retrieved memos.
    max_memos_to_retrieve: int  # Upper bound on the memos returned by retrieve_relevant_memos().
    max_train_trials: int  # Maximum number of learning iterations when training on a task.
    max_test_trials: int  # Number of attempts made when testing for failure on a task.
    MemoryBank: "MemoryBankConfig"  # Nested config dict passed through to MemoryBank.
|
||||
|
||||
|
||||
class MemoryController:
    """
    (EXPERIMENTAL, RESEARCH IN PROGRESS)

    Implements fast, memory-based learning, and manages the flow of information to and from a memory bank.

    Args:
        reset: True to empty the memory bank before starting.
        client: The model client to use internally.
        task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
        config: An optional dict that can be used to override the following values:

            - generalize_task: Whether to rewrite tasks in more general terms.
            - revise_generalized_task: Whether to critique then rewrite the generalized task.
            - generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
            - validate_memos: Whether to apply a final validation stage to retrieved memos.
            - max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
            - max_train_trials: The maximum number of learning iterations to attempt when training on a task.
            - max_test_trials: The total number of attempts made when testing for failure on a task.
            - MemoryBank: A config dict passed to MemoryBank.

        logger: An optional logger. If None, a default logger will be created.

    Example:

        The `task-centric-memory` extra first needs to be installed:

        .. code-block:: bash

            pip install "agentdhal-ext[task-centric-memory]"

        The following code snippet shows how to use this class for the most basic storage and retrieval of memories.:

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.experimental.task_centric_memory import MemoryController
            from agentdhal_extensions.experimental.task_centric_memory.utils import PageLogger


            async def main() -> None:
                client = OpenAIChatCompletionClient(model="gpt-4o")
                logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"})  # Optional, but very useful.
                memory_controller = MemoryController(reset=True, client=client, logger=logger)

                # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
                await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
                await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
                await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")

                # Retrieve memories for a new task that's related to only two of the stored memories.
                memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
                print("{} memories retrieved".format(len(memos)))
                for memo in memos:
                    print("- " + memo.insight)


            asyncio.run(main())
    """

    def __init__(
        self,
        reset: bool,
        client: ChatCompletionClient,
        task_assignment_callback: Callable[[str], Awaitable[Tuple[str, str]]] | None = None,
        config: MemoryControllerConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            logger = PageLogger({"level": "DEBUG"})
        self.logger = logger
        self.logger.enter_function()

        # Apply default settings and any config overrides.
        self.generalize_task = True
        self.revise_generalized_task = True
        self.generate_topics = True
        self.validate_memos = True
        self.max_memos_to_retrieve = 10
        self.max_train_trials = 10
        self.max_test_trials = 3
        memory_bank_config = None
        if config is not None:
            self.generalize_task = config.get("generalize_task", self.generalize_task)
            self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task)
            self.generate_topics = config.get("generate_topics", self.generate_topics)
            self.validate_memos = config.get("validate_memos", self.validate_memos)
            self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve)
            self.max_train_trials = config.get("max_train_trials", self.max_train_trials)
            self.max_test_trials = config.get("max_test_trials", self.max_test_trials)
            memory_bank_config = config.get("MemoryBank", memory_bank_config)

        self.client = client
        self.task_assignment_callback = task_assignment_callback
        self.prompter = Prompter(client, logger)
        self.memory_bank = MemoryBank(reset=reset, config=memory_bank_config, logger=logger)
        self.grader = Grader(client, logger)
        self.logger.leave_function()

    def reset_memory(self) -> None:
        """
        Empties the memory bank in RAM and on disk.
        """
        self.memory_bank.reset()

    async def train_on_task(self, task: str, expected_answer: str) -> None:
        """
        Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.

        If the iteration yields a successful insight, it is stored via add_memo() with the task as context.
        """
        self.logger.enter_function()
        self.logger.info("Iterate on the task, possibly discovering a useful new insight.\n")
        _, insight = await self._iterate_on_task(task, expected_answer)
        if insight is None:
            self.logger.info("No useful insight was discovered.\n")
        else:
            self.logger.info("A new insight was created:\n{}".format(insight))
            await self.add_memo(insight, task)
        self.logger.leave_function()

    async def test_on_task(self, task: str, expected_answer: str, num_trials: int = 1) -> Tuple[str, int, int]:
        """
        Assigns a task to the agent, along with any relevant memos retrieved from memory.

        Returns a tuple of (last response, number of correct responses, num_trials).
        With num_trials <= 0 no attempt is made and the success rate is reported as 0%.
        Requires task_assignment_callback to have been supplied.
        """
        self.logger.enter_function()
        assert self.task_assignment_callback is not None
        response = ""
        num_successes = 0

        for trial in range(num_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
            task_plus_insights = task

            # Try to retrieve any relevant memories from the DB.
            filtered_memos = await self.retrieve_relevant_memos(task)
            filtered_insights = [memo.insight for memo in filtered_memos]
            if len(filtered_insights) > 0:
                self.logger.info("Relevant insights were retrieved from memory.\n")
                memory_section = self._format_memory_section(filtered_insights)
                if len(memory_section) > 0:
                    task_plus_insights = task + "\n\n" + memory_section

            # Attempt to solve the task.
            self.logger.info("Try to solve the task.\n")
            response, _ = await self.task_assignment_callback(task_plus_insights)

            # Check if the response is correct.
            response_is_correct, extracted_answer = await self.grader.is_response_correct(
                task, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # Calculate the success rate as a percentage, rounded to the nearest whole number.
        # Guard the division so that zero trials reports 0% instead of raising ZeroDivisionError.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info("\nSuccess rate: {}%\n".format(success_rate))
        self.logger.leave_function()
        return response, num_successes, num_trials

    async def add_memo(self, insight: str, task: None | str = None, index_on_both: bool = True) -> None:
        """
        Adds one insight to the memory bank, using the task (if provided) as context.

        The text used for indexing is the insight alone when no task is given; otherwise it is the
        generalized task combined with the insight (index_on_both=True) or the original task only.
        """
        self.logger.enter_function()

        generalized_task = ""
        if task is not None:
            self.logger.info("\nGIVEN TASK:")
            self.logger.info(task)
            if self.generalize_task:
                generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
            else:
                generalized_task = task

        self.logger.info("\nGIVEN INSIGHT:")
        self.logger.info(insight)

        # Get a list of topics from the insight and the task (if provided).
        if task is None:
            text_to_index = insight
            self.logger.info("\nTOPICS EXTRACTED FROM INSIGHT:")
        else:
            if index_on_both:
                text_to_index = generalized_task.strip() + "\n(Hint: " + insight + ")"
                self.logger.info("\nTOPICS EXTRACTED FROM TASK AND INSIGHT COMBINED:")
            else:
                text_to_index = task
                self.logger.info("\nTOPICS EXTRACTED FROM TASK:")

        if self.generate_topics:
            topics = await self.prompter.find_index_topics(text_to_index)
        else:
            topics = [text_to_index]
        self.logger.info("\n".join(topics))
        self.logger.info("")

        # Add the insight to the memory bank.
        self.memory_bank.add_memo(insight, topics, task)
        self.logger.leave_function()

    async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
        """
        Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
        This is useful when the task-solution pair is an exemplar of solving a task related to some other task.
        """
        self.logger.enter_function()

        self.logger.info("\nEXAMPLE TASK:")
        self.logger.info(task)

        self.logger.info("\nEXAMPLE SOLUTION:")
        self.logger.info(solution)

        # Get a list of topics from the task.
        if self.generate_topics:
            topics = await self.prompter.find_index_topics(task.strip())
        else:
            topics = [task.strip()]
        self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
        self.logger.info("\n".join(topics))
        self.logger.info("")

        # Add the task and solution (as a combined insight) to the memory bank.
        self.memory_bank.add_task_with_solution(task=task, solution=solution, topics=topics)
        self.logger.leave_function()

    async def retrieve_relevant_memos(self, task: str) -> List[Memo]:
        """
        Retrieves any memos from memory that seem relevant to the task.

        Returns at most max_memos_to_retrieve memos, or an empty list when the memory bank holds no memos.
        """
        self.logger.enter_function()

        if self.memory_bank.contains_memos():
            self.logger.info("\nCURRENT TASK:")
            self.logger.info(task)

            # Get a list of topics from the generalized task.
            if self.generalize_task:
                generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
            else:
                generalized_task = task
            if self.generate_topics:
                task_topics = await self.prompter.find_index_topics(generalized_task)
            else:
                task_topics = [generalized_task]
            self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
            self.logger.info("\n".join(task_topics))
            self.logger.info("")

            # Retrieve relevant memos from the memory bank.
            memo_list = self.memory_bank.get_relevant_memos(topics=task_topics)

            # Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant.
            validated_memos: List[Memo] = []
            for memo in memo_list:
                if len(validated_memos) >= self.max_memos_to_retrieve:
                    break
                if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task):
                    validated_memos.append(memo)

            self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))
            for memo in validated_memos:
                if memo.task is not None:
                    self.logger.info("\n TASK: {}".format(memo.task))
                self.logger.info("\n INSIGHT: {}".format(memo.insight))
        else:
            self.logger.info("\nNO SUFFICIENTLY RELEVANT MEMOS WERE FOUND IN MEMORY")
            validated_memos = []

        self.logger.leave_function()
        return validated_memos

    def _format_memory_section(self, memories: List[str]) -> str:
        """
        Formats a list of memories as a section for appending to a task description.

        Returns an empty string when there are no memories.
        """
        if len(memories) == 0:
            return ""
        header = "## Important insights that may help solve tasks like this\n"
        return header + "".join("- " + mem + "\n" for mem in memories)

    async def _test_for_failure(
        self, task: str, task_plus_insights: str, expected_answer: str
    ) -> Tuple[bool, str, str]:
        """
        Attempts to solve the given task multiple times to find a failure case to learn from.

        Makes up to max_test_trials attempts and stops at the first incorrect response.
        Returns (failure_found, last response, last work history).
        """
        self.logger.enter_function()
        self.logger.info("\nTask description, including any insights: {}".format(task_plus_insights))
        self.logger.info("\nExpected answer: {}\n".format(expected_answer))

        assert self.task_assignment_callback is not None
        failure_found = False
        response, work_history = "", ""

        for trial in range(self.max_test_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))

            # Attempt to solve the task.
            self.logger.info("Try to solve the task.")
            response, work_history = await self.task_assignment_callback(task_plus_insights)

            response_is_correct, extracted_answer = await self.grader.is_response_correct(
                task, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
            else:
                self.logger.info("Answer is INCORRECT.\n Stop testing, and return the details of the failure.\n")
                failure_found = True
                break

        self.logger.leave_function()
        return failure_found, response, work_history

    async def _iterate_on_task(self, task: str, expected_answer: str) -> Tuple[str, None | str]:
        """
        Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.

        Returns (final_response, successful_insight). final_response is only non-empty when the
        very first trial finds no failure; successful_insight is only set when a later trial
        finds no failure after an insight was learned.
        """
        self.logger.enter_function()
        self.logger.info("\nTask description: {}".format(task))
        self.logger.info("\nExpected answer: {}\n".format(expected_answer))

        final_response = ""
        old_memos = await self.retrieve_relevant_memos(task)
        old_insights = [memo.insight for memo in old_memos]
        insight = None  # The most recent insight learned from a failure, if any.
        successful_insight = None

        # Loop until success (or timeout) while learning from failures.
        for trial in range(1, self.max_train_trials + 1):
            self.logger.info("\n----- TRAIN TRIAL {} -----\n".format(trial))
            task_plus_insights = task

            # Add the previously stored insights plus the latest new insight (only the latest one is kept).
            if insight is not None:
                memory_section = self._format_memory_section(old_insights + [insight])
            else:
                memory_section = self._format_memory_section(old_insights)
            if len(memory_section) > 0:
                task_plus_insights += "\n\n" + memory_section

            # Can we find a failure case to learn from?
            failure_found, response, work_history = await self._test_for_failure(
                task, task_plus_insights, expected_answer
            )
            if not failure_found:
                # No. Time to exit the loop.
                self.logger.info("\nResponse is CORRECT.\n Stop looking for insights.\n")
                # Was this the first trial?
                if trial == 1:
                    # Yes. We should return the successful response, and no insight.
                    final_response = response
                else:
                    # No. We learned a successful insight, which should be returned.
                    successful_insight = insight
                break

            # Will we try again?
            if trial == self.max_train_trials:
                # No. We're out of training trials.
                self.logger.info("\nNo more trials will be attempted.\n")
                break

            # Try to learn from this failure.
            self.logger.info("\nResponse is INCORRECT. Try to learn from this failure.\n")
            insight = await self.prompter.learn_from_failure(
                task, memory_section, response, expected_answer, work_history
            )
            self.logger.info("\nInsight: {}\n".format(insight))

        # Return the answer from the last loop.
        self.logger.info("\n{}\n".format(final_response))
        self.logger.leave_function()
        return final_response, successful_insight

    async def _append_any_relevant_memories(self, task: str) -> str:
        """
        Appends any relevant memories to the task description.
        """
        self.logger.enter_function()

        filtered_memos = await self.retrieve_relevant_memos(task)
        filtered_insights = [memo.insight for memo in filtered_memos]
        if len(filtered_insights) > 0:
            self.logger.info("Relevant insights were retrieved from memory.\n")
            memory_section = self._format_memory_section(filtered_insights)
            if len(memory_section) > 0:
                task = task + "\n\n" + memory_section

        self.logger.leave_function()
        return task

    async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
        """
        Assigns a task to some agent through the task_assignment_callback, along with any relevant memories.

        Requires task_assignment_callback to have been supplied. Only should_await=True is supported.
        """
        self.logger.enter_function()

        assert self.task_assignment_callback is not None

        if use_memory:
            task = await self._append_any_relevant_memories(task)

        # Attempt to solve the task.
        self.logger.info("Try to solve the task.\n")
        assert should_await
        response, _ = await self.task_assignment_callback(task)

        self.logger.leave_function()
        return response

    async def consider_memo_storage(self, text: str) -> str | None:
        """
        Tries to extract any advice from the given text and add it to memory.

        Returns the extracted advice, or None if none was found.
        """
        self.logger.enter_function()

        advice = await self.prompter.extract_advice(text)
        self.logger.info("Advice: {}".format(advice))
        if advice is not None:
            await self.add_memo(insight=advice)

        self.logger.leave_function()
        return advice

    async def handle_user_message(self, text: str, should_await: bool = True) -> str:
        """
        Handles a user message by extracting any advice as an insight to be stored in memory, and then calling assign_task().
        """
        self.logger.enter_function()

        # Check for advice.
        advice = await self.consider_memo_storage(text)

        # Assign the task through the task_assignment_callback, using memory only if no advice was just provided.
        response = await self.assign_task(text, use_memory=(advice is None), should_await=should_await)

        self.logger.leave_function()
        return response
|
||||
@@ -0,0 +1,15 @@
|
||||
from .apprentice import Apprentice, ApprenticeConfig
|
||||
from .chat_completion_client_recorder import ChatCompletionClientRecorder
|
||||
from .grader import Grader
|
||||
from .page_logger import PageLogger, PageLoggerConfig
|
||||
from .teachability import Teachability
|
||||
|
||||
# Public names re-exported from this package's submodules (each one is imported above).
__all__ = [
    "Apprentice",
    "ChatCompletionClientRecorder",
    "Grader",
    "PageLogger",
    "Teachability",
    "ApprenticeConfig",
    "PageLoggerConfig",
]
|
||||
@@ -0,0 +1,96 @@
|
||||
import hashlib
|
||||
import os
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from agentdhal_core import FunctionCall, Image
|
||||
from agentdhal_core.models import FunctionExecutionResult
|
||||
|
||||
# Convenience types
# UserContent: a string, or a list whose items are strings and/or Image objects.
UserContent = Union[str, List[Union[str, Image]]]
# AssistantContent: a string, or a list of FunctionCall objects.
AssistantContent = Union[str, List[FunctionCall]]
# FunctionExecutionContent: a list of FunctionExecutionResult objects.
FunctionExecutionContent = List[FunctionExecutionResult]
# SystemContent: always a plain string.
SystemContent = str
# MessageContent: any one of the four content shapes above.
MessageContent = UserContent | AssistantContent | SystemContent | FunctionExecutionContent
|
||||
|
||||
|
||||
def message_content_to_str(message_content: MessageContent | None) -> str:
    """
    Renders message content as one string.

    None becomes "", a string is returned unchanged, and a list is rendered one item per line:
    strings verbatim, Image items as "<Image>", and anything else as its right-stripped str().
    Any other input type raises AssertionError.
    """
    if message_content is None:
        return ""
    if isinstance(message_content, str):
        return message_content
    if not isinstance(message_content, List):
        raise AssertionError("Unexpected response type.")

    def render(entry: object) -> str:
        # Text passes through untouched; images are replaced by a placeholder.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, Image):
            return "<Image>"
        return str(entry).rstrip()

    return "\n".join([render(entry) for entry in message_content])
|
||||
|
||||
|
||||
def text_from_user_content(user_content: UserContent) -> str:
    """
    Returns only the textual part of the user content.

    A string is returned unchanged. For a list, every string item is right-stripped and the
    results are joined with blank lines; non-string items are ignored.
    Any other input type raises AssertionError.
    """
    if isinstance(user_content, str):
        return user_content
    if isinstance(user_content, List):
        return "\n\n".join([part.rstrip() for part in user_content if isinstance(part, str)])
    raise AssertionError("Unexpected response type.")
|
||||
|
||||
|
||||
def single_image_from_user_content(user_content: UserContent) -> Union[Image, None]:
    """
    Returns the one Image contained in the user content, or None if there is none.

    A string never contains an image. A list may contain at most one Image; a second one
    trips an assertion. Any other input type raises AssertionError.
    """
    if isinstance(user_content, str):
        return None
    if not isinstance(user_content, List):
        raise AssertionError("Unexpected response type.")

    found = None
    for candidate in user_content:
        if not isinstance(candidate, Image):
            continue
        assert found is None, "Only one image is currently allowed in the user content."
        found = candidate
    return found
|
||||
|
||||
|
||||
def hash_directory(directory: str, hash_algo: str = "sha256") -> Tuple[str, int, int]:
    """Computes a hash representing the state of a directory, including its structure and file contents.

    Args:
        directory: Root of the directory tree to walk.
        hash_algo: Any algorithm name accepted by hashlib.new().

    Returns:
        A tuple of (hex digest, total number of files, total number of sub-directories).

    Only the bare names of files and sub-directories (not their full paths) and the bytes of
    each file are fed to the hash. A file that cannot be opened or read still contributes its
    name (and any bytes read before the error), but the rest of its contents is skipped.
    """
    hash_func = hashlib.new(hash_algo)

    # Also count the number of files and sub-directories
    num_files = 0
    num_subdirs = 0

    # Visit directories sorted by path so the digest does not depend on OS listing order.
    for root, dirs, files in sorted(os.walk(directory)):
        num_files += len(files)
        num_subdirs += len(dirs)

        for dir_name in sorted(dirs):
            hash_func.update(dir_name.encode())  # Hash directory names

        for file_name in sorted(files):
            hash_func.update(file_name.encode())  # Hash file names
            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, "rb") as f:
                    while chunk := f.read(4096):  # Read in chunks
                        hash_func.update(chunk)
            except OSError:
                # Best effort: skip files that vanish or are unreadable, but let
                # non-I/O errors surface instead of being silently swallowed.
                continue

    return hash_func.hexdigest(), num_files, num_subdirs
|
||||
@@ -0,0 +1,257 @@
|
||||
import random
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.base import TaskResult
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from .page_logger import PageLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..memory_controller import MemoryControllerConfig
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class ApprenticeConfig(TypedDict, total=False):
    # Optional overrides for Apprentice. Every key may be omitted (total=False);
    # Apprentice.__init__ falls back to its own default for any missing key.
    name_of_agent_or_team: str  # The name of the target agent or team for assigning tasks to.
    disable_prefix_caching: bool  # True to disable prefix caching by prepending random ints to the first message.
    MemoryController: "MemoryControllerConfig"  # Nested config dict passed through to MemoryController.
|
||||
|
||||
|
||||
class Apprentice:
|
||||
"""
|
||||
A minimal wrapper combining task-centric memory with an agent or team.
|
||||
Applications may use the Apprentice class, or they may directly instantiate
|
||||
and call the Memory Controller using this class as an example.
|
||||
|
||||
Args:
|
||||
client: The client to call the model.
|
||||
config: An optional dict that can be used to override the following values:
|
||||
|
||||
- name_of_agent_or_team: The name of the target agent or team for assigning tasks to.
|
||||
- disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message.
|
||||
- MemoryController: A config dict passed to MemoryController.
|
||||
|
||||
logger: An optional logger. If None, a default logger will be created.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    client: ChatCompletionClient,
    config: ApprenticeConfig | None = None,
    logger: PageLogger | None = None,
) -> None:
    """
    Stores the client, resolves settings from config, and builds the MemoryController.

    The MemoryController is always created with reset=True and is given this object's
    assign_task_to_agent_or_team method as its task-assignment callback.
    """
    self.logger = PageLogger({"level": "DEBUG"}) if logger is None else logger

    # Resolve settings: built-in defaults, replaced by any values present in config.
    overrides = {} if config is None else config
    self.name_of_agent_or_team = overrides.get("name_of_agent_or_team", "AssistantAgent")
    self.disable_prefix_caching = overrides.get("disable_prefix_caching", False)
    memory_controller_config = overrides.get("MemoryController", None)

    self.client = client
    if self.disable_prefix_caching:
        # A private RNG seeded from the current time in milliseconds.
        self.rand = random.Random()
        self.rand.seed(int(time.time() * 1000))

    # Create the MemoryController, which creates the MemoryBank.
    from ..memory_controller import MemoryController

    self.memory_controller = MemoryController(
        reset=True,
        client=self.client,
        task_assignment_callback=self.assign_task_to_agent_or_team,
        config=memory_controller_config,
        logger=self.logger,
    )
|
||||
|
||||
def reset_memory(self) -> None:
|
||||
"""
|
||||
Resets the memory bank.
|
||||
"""
|
||||
self.memory_controller.reset_memory()
|
||||
|
||||
async def handle_user_message(self, text: str, should_await: bool = True) -> str:
|
||||
"""
|
||||
Handles a user message, extracting any advice and assigning a task to the agent.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
|
||||
# Pass the user message through to the memory controller.
|
||||
response = await self.memory_controller.handle_user_message(text, should_await)
|
||||
|
||||
self.logger.leave_function()
|
||||
return response
|
||||
|
||||
async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
|
||||
"""
|
||||
Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
|
||||
This is useful when the insight is a demonstration of how to solve a given type of task.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
|
||||
# Pass the task and solution through to the memory controller.
|
||||
await self.memory_controller.add_task_solution_pair_to_memory(task, solution)
|
||||
|
||||
self.logger.leave_function()
|
||||
|
||||
async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
|
||||
"""
|
||||
Assigns a task to the agent, along with any relevant insights/memories.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
|
||||
# Pass the task through to the memory controller.
|
||||
response = await self.memory_controller.assign_task(task, use_memory, should_await)
|
||||
|
||||
self.logger.leave_function()
|
||||
return response
|
||||
|
||||
async def train_on_task(self, task: str, expected_answer: str) -> None:
|
||||
"""
|
||||
Repeatedly assigns a task to the completion agent, and tries to learn from failures by creating useful insights as memories.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
|
||||
# Pass the task through to the memory controller.
|
||||
await self.memory_controller.train_on_task(task, expected_answer)
|
||||
|
||||
self.logger.leave_function()
|
||||
|
||||
async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Passes the given task to the target agent or team.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
|
||||
# Pass the task through.
|
||||
if self.name_of_agent_or_team == "MagenticOneGroupChat":
|
||||
response, work_history = await self._assign_task_to_magentic_one(task)
|
||||
elif self.name_of_agent_or_team == "AssistantAgent":
|
||||
response, work_history = await self._assign_task_to_assistant_agent(task)
|
||||
else:
|
||||
raise AssertionError("Invalid base agent")
|
||||
|
||||
self.logger.leave_function()
|
||||
return response, work_history
|
||||
|
||||
async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Passes the given task to a newly created AssistantAgent with a generic 6-step system prompt.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
self.logger.info(task)
|
||||
|
||||
system_message_content = """You are a helpful and thoughtful assistant.
|
||||
In responding to every user message, you follow the same multi-step process given here:
|
||||
1. Explain your understanding of the user message in detail, covering all the important points.
|
||||
2. List as many possible responses as you can think of.
|
||||
3. Carefully list and weigh the pros and cons (if any) of each possible response.
|
||||
4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist.
|
||||
5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory.
|
||||
6. Finish by providing your final response in the particular format requested by the user."""
|
||||
|
||||
if self.disable_prefix_caching:
|
||||
# Prepend a random int to disable prefix caching.
|
||||
random_str = "({})\n\n".format(self.rand.randint(0, 1000000))
|
||||
system_message_content = random_str + system_message_content
|
||||
|
||||
system_message: LLMMessage
|
||||
if self.client.model_info["family"] == "o1":
|
||||
# No system message allowed, so pass it as the first user message.
|
||||
system_message = UserMessage(content=system_message_content, source="User")
|
||||
else:
|
||||
# System message allowed.
|
||||
system_message = SystemMessage(content=system_message_content)
|
||||
|
||||
user_message: LLMMessage = UserMessage(content=task, source="User")
|
||||
system_message_list: List[LLMMessage] = [system_message]
|
||||
user_message_list: List[LLMMessage] = [user_message]
|
||||
input_messages: List[LLMMessage] = system_message_list + user_message_list
|
||||
|
||||
assistant_agent = AssistantAgent(
|
||||
"assistant_agent",
|
||||
self.client,
|
||||
system_message=system_message_content,
|
||||
)
|
||||
|
||||
# Get the agent's response to the task.
|
||||
task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
|
||||
messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
|
||||
message: BaseAgentEvent | BaseChatMessage = messages[-1]
|
||||
response_str = message.to_text()
|
||||
|
||||
# Log the model call
|
||||
self.logger.log_model_task(
|
||||
summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result
|
||||
)
|
||||
self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))
|
||||
|
||||
# Use the response as the work history as well.
|
||||
work_history = response_str
|
||||
|
||||
self.logger.leave_function()
|
||||
return response_str, work_history
|
||||
|
||||
async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Instantiates a MagenticOneGroupChat team, and passes the given task to it.
|
||||
"""
|
||||
self.logger.enter_function()
|
||||
self.logger.info(task)
|
||||
|
||||
general_agent = AssistantAgent(
|
||||
"general_agent",
|
||||
self.client,
|
||||
description="A general GPT-4o AI assistant capable of performing a variety of tasks.",
|
||||
)
|
||||
|
||||
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
|
||||
|
||||
web_surfer = MultimodalWebSurfer(
|
||||
name="web_surfer",
|
||||
model_client=self.client,
|
||||
downloads_folder="logs",
|
||||
debug_dir="logs",
|
||||
to_save_screenshots=True,
|
||||
)
|
||||
|
||||
from agentdhal_agentchat.teams import MagenticOneGroupChat
|
||||
|
||||
team = MagenticOneGroupChat(
|
||||
[general_agent, web_surfer],
|
||||
model_client=self.client,
|
||||
max_turns=20,
|
||||
)
|
||||
|
||||
# Get the team's response to the task.
|
||||
task_result: TaskResult = await team.run(task=task)
|
||||
|
||||
assert isinstance(task_result, TaskResult)
|
||||
messages = task_result.messages
|
||||
|
||||
response_str_list: List[str] = []
|
||||
for message in messages:
|
||||
response_str_list.append(message.to_text())
|
||||
response_str = "\n".join(response_str_list)
|
||||
|
||||
self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))
|
||||
|
||||
# MagenticOne's response is the chat history, which we use here as the work history.
|
||||
work_history = response_str
|
||||
|
||||
self.logger.leave_function()
|
||||
return response_str, work_history
|
||||
@@ -0,0 +1,227 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
ModelCapabilities, # type: ignore
|
||||
ModelInfo,
|
||||
RequestUsage,
|
||||
)
|
||||
from agentdhal_core.tools import Tool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .page_logger import PageLogger
|
||||
|
||||
|
||||
class RecordDict(TypedDict):
    """One recorded model call, as stored in the JSON session file."""

    # Which client method the record belongs to.
    mode: Literal["create", "create_stream"]
    # The input messages, each serialized with model_dump().
    messages: List[Mapping[str, Any]]
    # The serialized result of a create() call.
    response: Dict[str, Any]
    # Streamed outputs for create_stream records; create() records store an empty list.
    stream: List[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ChatCompletionClientRecorder(ChatCompletionClient):
    """
    A chat completion client that supports fast, large-scale tests of code calling LLM clients.

    Two modes are supported:

      1. "record": delegates to the underlying client while also recording the input messages and responses,
         which are saved to disk when finalize() is called.
      2. "replay": loads previously recorded message and responses from disk, then on each call
         checks that its message matches the recorded message, and returns the recorded response.

    The recorded data is stored as a JSON list of records. Each record is a dictionary with a "mode"
    field (either "create" or "create_stream"), a serialized list of messages, and either a "response" (for
    create calls) or a "stream" (a list of streamed outputs for create_stream calls).

    Note that create_stream() in this class only delegates to the base client; it neither
    records nor replays.

    ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:

      - ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client.
      - ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages.

    Args:
        client: The base client that real calls are delegated to.
        mode: Either "record" or "replay".
        session_file_path: Path of the JSON session file; "~" is expanded.
        logger: An optional logger. If None, a disabled PageLogger is used.

    Raises:
        ValueError: In replay mode, if the session file cannot be loaded.
    """

    def __init__(
        self,
        client: ChatCompletionClient,
        mode: Literal["record", "replay"],
        session_file_path: str,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            self.logger = PageLogger()  # Disabled by default.
        else:
            self.logger = logger
        self.logger.enter_function()
        self.logger.info("Wrapping the base client in ChatCompletionClientRecorder.")

        self.base_client = client
        self.mode = mode
        self.session_file_path = os.path.expanduser(session_file_path)
        self.records: List[RecordDict] = []
        # Position of the next record to replay, and how many records passed the check.
        self._record_index = 0
        self._num_checked_records = 0
        if self.mode == "record":
            # Prepare to record the messages and responses.
            self.logger.info("Recording mode enabled.\nRecording session to: " + self.session_file_path)
        elif self.mode == "replay":
            # Load the previously recorded messages and responses from disk.
            self.logger.info("Replay mode enabled.\nRetrieving session from: " + self.session_file_path)
            try:
                with open(self.session_file_path, "r", encoding="utf-8") as f:
                    self.records = json.load(f)
            except Exception as e:
                error_str = f"\nFailed to load recorded session: '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e

        self.logger.leave_function()

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> CreateResult:
        """
        Records or replays a single create() call, depending on the mode.

        In record mode the call is delegated to the base client and the messages and
        response are appended to the in-memory records. In replay mode the next record
        must be a "create" record whose messages equal the current messages; its stored
        response is returned as a CreateResult with cached=True.

        Raises:
            ValueError: In replay mode, if no records remain, the recorded call type is
                not "create", or the messages differ from the recording. Also raised
                for an unknown mode.
        """
        current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]
        if self.mode == "record":
            response = await self.base_client.create(
                messages,
                tools=tools,
                json_output=json_output,
                tool_choice=tool_choice,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )

            rec: RecordDict = {
                "mode": "create",
                "messages": current_messages,
                "response": response.model_dump(),
                "stream": [],
            }
            self.records.append(rec)
            return response
        elif self.mode == "replay":
            if self._record_index >= len(self.records):
                error_str = "\nNo more recorded turns to check."
                self.logger.error(error_str)
                raise ValueError(error_str)
            rec = self.records[self._record_index]
            if rec.get("mode") != "create":
                error_str = f"\nRecorded call type mismatch at index {self._record_index}: expected 'create', got '{rec.get('mode')}'."
                self.logger.error(error_str)
                raise ValueError(error_str)
            recorded_messages = rec.get("messages")
            if recorded_messages != current_messages:
                error_str = (
                    "\nCurrent message list doesn't match the recorded message list. See the pagelogs for details."
                )
                assert recorded_messages is not None
                self.logger.log_dict_list(recorded_messages, "recorded message list")
                assert current_messages is not None
                self.logger.log_dict_list(current_messages, "current message list")
                self.logger.error(error_str)
                raise ValueError(error_str)
            # The record matched; advance to the next one.
            self._record_index += 1
            self._num_checked_records += 1

            data = rec.get("response")
            # Populate a CreateResult from the data.
            assert data is not None
            result = CreateResult(
                content=data.get("content", ""),
                finish_reason=data.get("finish_reason", "stop"),
                usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
                cached=True,
            )
            return result

        else:
            error_str = f"\nUnknown mode: {self.mode}"
            self.logger.error(error_str)
            raise ValueError(error_str)

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Delegates to base_client.create_stream(); nothing is recorded or replayed."""
        return self.base_client.create_stream(
            messages,
            tools=tools,
            tool_choice=tool_choice,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )

    async def close(self) -> None:
        """Closes the base client."""
        await self.base_client.close()

    def actual_usage(self) -> RequestUsage:
        """Returns base_client.actual_usage()."""
        return self.base_client.actual_usage()

    def total_usage(self) -> RequestUsage:
        """Returns base_client.total_usage()."""
        return self.base_client.total_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Returns base_client.count_tokens() for the given messages and tools."""
        return self.base_client.count_tokens(messages, tools=tools)

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Returns base_client.remaining_tokens() for the given messages and tools."""
        return self.base_client.remaining_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        """Returns base_client.capabilities after emitting a DeprecationWarning."""
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.base_client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        """Returns base_client.model_info."""
        return self.base_client.model_info

    def finalize(self) -> None:
        """
        In record mode, saves the accumulated records to disk.
        In replay mode, makes sure all the records were checked.

        Raises:
            ValueError: If the records cannot be written (record mode), or if fewer
                records were checked than were recorded (replay mode).
        """
        self.logger.enter_function()
        if self.mode == "record":
            try:
                # Create the directory if it doesn't exist. A bare file name has no
                # directory part, and os.makedirs("") would raise, so skip it then.
                session_dir = os.path.dirname(self.session_file_path)
                if session_dir:
                    os.makedirs(session_dir, exist_ok=True)
                # Write the records to disk.
                with open(self.session_file_path, "w", encoding="utf-8") as f:
                    json.dump(self.records, f, indent=2)
                    self.logger.info("\nRecorded session was saved to: " + self.session_file_path)
            except Exception as e:
                error_str = f"Failed to write records to '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e
        elif self.mode == "replay":
            if self._num_checked_records < len(self.records):
                error_str = f"\nEarly termination. Only {self._num_checked_records} of the {len(self.records)} recorded turns were checked."
                self.logger.error(error_str)
                raise ValueError(error_str)
            self.logger.info("\nRecorded session was fully replayed and checked.")
        self.logger.leave_function()
|
||||
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from ._functions import UserContent
|
||||
from .page_logger import PageLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .apprentice import Apprentice
|
||||
|
||||
|
||||
class Grader:
    """
    Runs basic tests, and determines task success without limitation to string matches.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.client = client

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def test_apprentice(
        self,
        apprentice: Apprentice,
        task_description: str,
        expected_answer: str,
        num_trials: int,
        use_memory: bool,
        client: ChatCompletionClient,
    ) -> Tuple[int, int]:
        """
        Assigns the task to the apprentice num_trials times and grades each response.

        Args:
            apprentice: The apprentice under test.
            task_description: The task handed to the apprentice on every trial.
            expected_answer: The answer each response is graded against.
            num_trials: How many times to run the task. Zero trials yields (0, 0).
            use_memory: Forwarded to apprentice.assign_task().
            client: Not used by this method; kept for interface compatibility.

        Returns:
            A (num_successes, num_trials) pair.
        """
        self.logger.enter_function()

        self.logger.info("Testing the apprentice on the given task.\n")

        num_successes = 0

        for trial in range(num_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
            self.logger.info("Try to solve the task.\n")
            response = await apprentice.assign_task(task_description, use_memory=use_memory)
            response_is_correct, extracted_answer = await self.is_response_correct(
                task_description, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # With no trials there is no rate to compute; report 0% rather than dividing by zero.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info("\nSuccess rate: {}%\n".format(success_rate))
        self.logger.leave_function()
        return num_successes, num_trials

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: Short description of the call, passed to the logger.
            user_content: Content of the user message.
            system_message_content: System prompt; defaults to "You are a helpful assistant."
            keep_these_messages: True to append the user message and the response to the chat history.

        Returns:
            The model's response content as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = "You are a helpful assistant."
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)
        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Call the model.
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")
        assert isinstance(response_message, AssistantMessage)

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def is_response_correct(
        self, task_description: str, response_to_be_graded: str, correct_answer: str
    ) -> Tuple[bool, str]:
        """
        Determines whether the response is equivalent to the task's correct answer.

        Two model calls are made, each on a cleared chat history: the first extracts the
        answer from the response, the second grades the extracted answer.

        Returns:
            A (is_correct, extracted_answer) pair.
        """
        self.logger.enter_function()

        sys_message = """You are a helpful and thoughtful assistant."""

        # NOTE: the prompt wording below is kept verbatim. ChatCompletionClientRecorder
        # compares replayed message lists exactly, so rewording would break recorded sessions.

        # Ask the model to extract the answer from the response.
        user_message: List[Union[str, Image]] = []
        user_message.append("""Your job is to extract a possible answer to the following question from the given text.
- First review the following task.
- Then review the text that follows, which may an answer, plus reasoning that led to the answer.
- Do not attempt to actually solve the task yourself.
- Don't try to judge whether the reasoning steps were correct.
- Simply respond by summarizing the answer described in the text, omitting any other parts of the text.
- If no answer is present can be extracted from the text, simply reply "None".""")
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Text that may contain an answer")
        user_message.append(response_to_be_graded)
        user_message_arg: UserContent = user_message
        self._clear_history()
        extracted_answer = await self.call_model(
            summary="Ask the model to extract the answer",
            system_message_content=sys_message,
            user_content=user_message_arg,
        )
        self.logger.info("Extracted answer: " + extracted_answer)

        # Ask the model to check the answer for correctness.
        user_message = [
            """Your job is to decide whether a given answer to a task is correct or not.
- You will be given the task description and the correct, gold-standard answer, along with the answer to be graded.
- In general, an answer is correct if it is equivalent to the correct answer.
- Specifically, the given answer must contain the important information from the correct answer, and must not in any way contradict the correct answer.
- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, or extra commentary.
- An answer should be considered correct if it omits information that is clearly inferred.
- For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct.
- Respond with a single character: '1' if the answer to be graded is correct", '0' if not."""
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Correct answer")
        user_message.append(correct_answer)
        user_message.append("\n# Answer to be graded")
        user_message.append(extracted_answer)
        self._clear_history()
        decision = await self.call_model(
            summary="Ask the model to check the answer for correctness",
            system_message_content=sys_message,
            user_content=user_message,
        )
        self.logger.info("Decision: " + decision)

        self.logger.leave_function()
        # Surrounding whitespace (e.g. a trailing newline) must not turn a "1" into a failure.
        return decision.strip() == "1", extracted_answer
|
||||
@@ -0,0 +1,546 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
|
||||
|
||||
from agentdhal_agentchat.base import TaskResult
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from ._functions import MessageContent, hash_directory
|
||||
|
||||
|
||||
def _html_opening(file_title: str, finished: bool = False) -> str:
    """
    Builds the opening markup of a simple HTML page, up to and including <body>.

    Unfinished pages carry a meta refresh tag with content="2"; finished pages carry none.
    """
    refresh_tag = "" if finished else '<meta http-equiv="refresh" content="2">'
    return f"""
<!DOCTYPE html>
<html>
<head>
{refresh_tag}
<title>{file_title}</title>
<style>
body {{font-size: 20px}}
body {{white-space: pre-wrap}}
</style>
</head>
<body>"""
|
||||
|
||||
|
||||
def _html_closing() -> str:
    """
    Builds the closing markup of a simple HTML page.
    """
    closing_tags = "</body></html>"
    return closing_tags
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class PageLoggerConfig(TypedDict, total=False):
    """Optional overrides read by PageLogger.__init__; every key may be omitted (total=False)."""

    # Logging level name: DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
    level: str
    # Directory the log files are written to; "~" is expanded by PageLogger.
    path: str
|
||||
|
||||
|
||||
class PageLogger:
|
||||
"""
|
||||
Logs text and images to a set of HTML pages, one per function/method, linked to each other in a call tree.
|
||||
|
||||
Args:
|
||||
config: An optional dict that can be used to override the following values:
|
||||
|
||||
- level: The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
|
||||
- path: The path to the directory where the log files will be written.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PageLoggerConfig | None = None) -> None:
    """
    Resolves the logging level and log directory, then sets up page tracking if logging is enabled.

    With no config the level is NONE, so nothing is logged and no directory is created.
    An unrecognized level name raises KeyError.
    """
    self.levels = {
        "DEBUG": 10,
        "INFO": 20,
        "WARNING": 30,
        "ERROR": 40,
        "CRITICAL": 50,
        "NONE": 100,
    }

    # Defaults apply unless the config carries an override for the key.
    settings = {} if config is None else config
    level_str = settings.get("level", "NONE")
    self.log_dir = settings.get("path", "./pagelogs/default")
    self.level = self.levels[level_str]
    self.log_dir = os.path.expanduser(self.log_dir)

    # Page tracking and the run directory are only needed when something will be logged.
    if self.level < self.levels["NONE"]:
        self.page_stack = PageStack()
        self.pages: List[Page] = []
        self.last_page_id = 0
        self.name = "0 Call Tree"
        self._create_run_dir()
        self.flush()
        self.finalized = False
|
||||
|
||||
def __del__(self) -> None:
    """
    Finalizes the log when the logger is garbage-collected.

    Python still runs __del__ when __init__ raised partway through, so the attributes
    that finalize() reads may be missing. In that case there is nothing to finalize.
    """
    level = getattr(self, "level", None)
    levels = getattr(self, "levels", None)
    if level is None or levels is None:
        # __init__ failed before the logging level was resolved.
        return
    if level < levels["NONE"] and not hasattr(self, "finalized"):
        # Logging was enabled, but __init__ never reached the end of its setup.
        return
    self.finalize()
|
||||
|
||||
def finalize(self) -> None:
    """
    Flushes the log one last time and writes a hash of the log directory to hash.txt for change detection.

    Does nothing when logging is disabled, when the log has already been finalized, or
    when pages are still on the page stack (the app is being forced to exit early).
    """
    logging_disabled = self.level >= self.levels["NONE"]
    if logging_disabled or self.finalized or self.page_stack.size() > 0:
        return

    self.flush(finished=True)

    # Record the hash along with the file and subdirectory counts.
    digest, file_count, subdir_count = hash_directory(self.log_dir)
    with open(os.path.join(self.log_dir, "hash.txt"), "w") as f:
        f.writelines(
            [
                digest,
                "\n",
                "{} files\n".format(file_count),
                "{} subdirectories\n".format(subdir_count),
            ]
        )

    self.finalized = True
|
||||
|
||||
@staticmethod
|
||||
def _decorate_text(text: str, color: str, weight: str = "bold", demarcate: bool = False) -> str:
|
||||
"""
|
||||
Returns a string of text with HTML styling for weight and color.
|
||||
"""
|
||||
if demarcate:
|
||||
text = f"<<<<< {text} >>>>>"
|
||||
return f'<span style="color: {color}; font-weight: {weight};">{text}</span>'
|
||||
|
||||
@staticmethod
|
||||
def _link_to_image(image_path: str, description: str) -> str:
|
||||
"""
|
||||
Returns an HTML string defining a thumbnail link to an image.
|
||||
"""
|
||||
# To avoid a bug in heml rendering aht displays underscores to the left of thumbnails,
|
||||
# define the following string on a single line.
|
||||
link = f"""<a href="{image_path}"><img src="{image_path}" alt="{description}" style="width: 300px; height: auto;"></a>"""
|
||||
return link
|
||||
|
||||
def _get_next_page_id(self) -> int:
    """Advances the page id counter by one and returns the new value."""
    next_id = self.last_page_id + 1
    self.last_page_id = next_id
    return next_id
|
||||
|
||||
def _create_run_dir(self) -> None:
    """Removes any existing log directory, then creates it anew and empty."""
    run_dir = self.log_dir
    already_there = os.path.exists(run_dir)
    if already_there:
        shutil.rmtree(run_dir)
    os.makedirs(run_dir)
|
||||
|
||||
def _add_page(self, summary: str, show_in_call_tree: bool = True, finished: bool = True) -> "Page":
    """
    Creates a new page, registers it with the log, and links it from the calling page if there is one.

    The new page's indent level is the current depth of the page stack.
    """
    new_page = Page(
        page_logger=self,
        index=self._get_next_page_id(),
        summary=summary,
        indent_level=len(self.page_stack.stack),
        show_in_call_tree=show_in_call_tree,
        finished=finished,
    )
    self.pages.append(new_page)
    self.flush()

    # With a caller on the stack, the caller's page gets a link to the new page.
    has_caller = len(self.page_stack.stack) > 0
    if has_caller:
        self.info("\n" + new_page.full_link)
    return new_page
|
||||
|
||||
def _log_text(self, text: str) -> None:
    """
    Appends text to the page on top of the page stack, flushing it; does nothing if there is no such page.
    """
    current_page = self.page_stack.top()
    if current_page is None:
        return
    current_page.add_lines(text, flush=True)
|
||||
|
||||
def debug(self, line: str) -> None:
    """
    Writes the line to the current page when the logging level is DEBUG or lower.
    """
    if self.level > self.levels["DEBUG"]:
        return
    self._log_text(line)
|
||||
|
||||
def info(self, line: str) -> None:
    """
    Writes the line to the current page when the logging level is INFO or lower.
    """
    if self.level > self.levels["INFO"]:
        return
    self._log_text(line)
|
||||
|
||||
def warning(self, line: str) -> None:
    """
    Writes the line to the current page when the logging level is WARNING or lower.
    """
    if self.level > self.levels["WARNING"]:
        return
    self._log_text(line)
|
||||
|
||||
def error(self, line: str) -> None:
    """
    Writes the line to the current page when the logging level is ERROR or lower.
    """
    if self.level > self.levels["ERROR"]:
        return
    self._log_text(line)
|
||||
|
||||
def critical(self, line: str) -> None:
    """
    Writes the line to the current page when the logging level is CRITICAL or lower.
    """
    if self.level > self.levels["CRITICAL"]:
        return
    self._log_text(line)
|
||||
|
||||
def _message_source(self, message: LLMMessage) -> str:
    """
    Builds the decorated source label for `message`.

    The label and color are chosen from the message's class; a message matching
    none of the known classes is labelled "UNKNOWN" in black. The result is
    produced by `_decorate_text(..., demarcate=True)`.
    """
    # Checked in order; the first class the message is an instance of wins.
    known_sources = (
        (SystemMessage, "SYSTEM", "purple"),
        (UserMessage, "USER", "blue"),
        (AssistantMessage, "ASSISTANT", "green"),
        (FunctionExecutionResultMessage, "FUNCTION", "red"),
    )
    label, label_color = "UNKNOWN", "black"
    for message_type, type_label, type_color in known_sources:
        if isinstance(message, message_type):
            label, label_color = type_label, type_color
            break
    return self._decorate_text(label, label_color, demarcate=True)
|
||||
|
||||
def _format_message_content(self, message_content: MessageContent) -> str:
    """
    Renders message content as a single string for a log page.

    A plain string is used as is. For a list, each item is converted in turn:
    strings are right-stripped, images are saved into the log directory and
    replaced by a link, dicts are dumped as indented JSON, and anything else is
    converted with `str()` and right-stripped. Content that is neither a string
    nor a list yields a placeholder. Every resulting piece is wrapped in a
    leading and trailing newline.
    """
    pieces: List[str] = []
    if isinstance(message_content, str):
        pieces.append(message_content)
    elif isinstance(message_content, list):
        for part in message_content:
            if isinstance(part, str):
                pieces.append(part.rstrip())
            elif isinstance(part, Image):
                # The image is written to disk; the page only carries a link to it.
                image_filename = str(self._get_next_page_id()) + " image.jpg"
                part.image.save(os.path.join(self.log_dir, image_filename))
                pieces.append(self._link_to_image(image_filename, "message_image"))
            elif isinstance(part, Dict):
                pieces.append(json.dumps(part, indent=4))
            else:
                pieces.append(str(part).rstrip())
    else:
        pieces.append("<UNKNOWN MESSAGE CONTENT>")

    # Each piece sits on its own line, separated from its neighbours by blank lines.
    return "".join(f"\n{piece}\n" for piece in pieces)
|
||||
|
||||
def log_message_content(self, message_content: MessageContent, summary: str) -> None:
    """
    Writes the message content (with links to any saved images) to a new page.

    The page is not shown in the call tree. Nothing is logged when the logger's
    level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None
    content_page = self._add_page(summary=summary, show_in_call_tree=False)
    self.page_stack.write_stack_to_page(content_page)
    formatted = self._format_message_content(message_content=message_content)
    content_page.add_lines(formatted)
    content_page.flush()
|
||||
|
||||
def log_dict_list(self, content: List[Mapping[str, Any]], summary: str) -> None:
    """
    Writes each mapping in `content` as indented JSON to a new page.

    The page is not shown in the call tree. Nothing is logged when the logger's
    level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None
    dict_page = self._add_page(summary=summary, show_in_call_tree=False)
    self.page_stack.write_stack_to_page(dict_page)

    for entry in content:
        dict_page.add_lines(json.dumps(entry, indent=4))

    dict_page.flush()
|
||||
|
||||
def _log_model_messages(
    self, summary: str, input_messages: List[LLMMessage], response_str: str, usage: RequestUsage | None
) -> Optional["Page"]:
    """
    Writes a model exchange to a new page and returns that page.

    The page (not shown in the call tree) lists the token counts when `usage` is
    given, then every input message preceded by its decorated source label, and
    finally the response text under an "ASSISTANT RESPONSE" heading.
    """
    exchange_page = self._add_page(summary=summary, show_in_call_tree=False)
    self.page_stack.write_stack_to_page(exchange_page)

    if usage is not None:
        exchange_page.add_lines(f"{usage.prompt_tokens} prompt tokens")
        exchange_page.add_lines(f"{usage.completion_tokens} completion tokens")

    for message in input_messages:
        source_label = self._message_source(message)
        exchange_page.add_lines("\n" + source_label)
        exchange_page.add_lines(self._format_message_content(message_content=message.content))

    response_heading = self._decorate_text("ASSISTANT RESPONSE", "green", demarcate=True)
    exchange_page.add_lines("\n" + response_heading)
    exchange_page.add_lines("\n" + response_str + "\n")
    exchange_page.flush()
    return exchange_page
|
||||
|
||||
def log_model_call(
    self, summary: str, input_messages: List[LLMMessage], response: CreateResult
) -> Optional["Page"]:
    """
    Logs the messages sent to a model and the CreateResult response to a new page.

    Returns the page, or None when the logger's level is above INFO. A response
    whose content is not a string is logged as "??".
    """
    if self.level > self.levels["INFO"]:
        return None

    raw_content = response.content
    response_text = raw_content if isinstance(raw_content, str) else "??"

    return self._log_model_messages(summary, input_messages, response_text, response.usage)
|
||||
|
||||
def log_model_task(
    self, summary: str, input_messages: List[LLMMessage], task_result: TaskResult
) -> Optional["Page"]:
    """
    Logs the messages sent to a model and the TaskResult response to a new page.

    The response text is taken from the last message of `task_result` via
    `to_text()` ("??" if that is not a string). Usage is read from the message's
    `models_usage` attribute when it has one; otherwise zero token counts are
    logged. Returns the page, or None when the logger's level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None

    task_messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
    final_message = task_messages[-1]

    response_text = final_message.to_text()
    if not isinstance(response_text, str):
        response_text = "??"

    usage: RequestUsage | None = (
        final_message.models_usage
        if hasattr(final_message, "models_usage")
        else RequestUsage(prompt_tokens=0, completion_tokens=0)
    )

    return self._log_model_messages(summary, input_messages, response_text, usage)
|
||||
|
||||
def log_link_to_local_file(self, file_path: str) -> str:
    """
    Builds an HTML anchor that points at a local file by its base name.

    Only the final path component of `file_path` is used, both as the link
    target and as the link text.
    """
    base_name = os.path.basename(file_path)
    return '<a href="' + base_name + '">' + base_name + "</a>"
|
||||
|
||||
def add_link_to_image(self, description: str, source_image_path: str) -> None:
    """
    Copies an image into the log directory and logs a link to it on the current page.

    `description` is reduced to its alphanumeric and whitespace characters; the
    cleaned text is used in the copied file's name, as the caption line, and as
    the description passed to `_link_to_image`.
    """
    # Drop every character that is neither alphanumeric nor whitespace.
    description = "".join(filter(lambda ch: ch.isalnum() or ch.isspace(), description))
    target_image_filename = str(self._get_next_page_id()) + " - " + description
    # The copy lives next to the HTML pages so the link can be relative.
    shutil.copyfile(source_image_path, os.path.join(self.log_dir, target_image_filename))
    self._log_text("\n" + description)
    self._log_text(self._link_to_image(target_image_filename, description))
|
||||
|
||||
def flush(self, finished: bool = False) -> None:
    """
    Rewrites the call-tree HTML file from the logger's current pages.

    The file is `<log_dir>/<name>.html` and lists the link line of every page
    flagged `show_in_call_tree`. Nothing is written when the logger's level is
    above INFO.
    """
    if self.level > self.levels["INFO"]:
        return
    tree_path = os.path.join(self.log_dir, self.name + ".html")
    with open(tree_path, "w") as out:
        out.write(_html_opening("0 Call Tree", finished=finished))
        out.write(f"<h3>{self.name}</h3>")
        out.write("\n")
        listed_pages = (page for page in self.pages if page.show_in_call_tree)
        out.writelines(page.line_text + "\n" for page in listed_pages)
        out.write("\n")
        out.write(_html_closing())
|
||||
|
||||
def enter_function(self) -> Optional["Page"]:
    """
    Opens a page for the function that called this method and pushes it on the page stack.

    The page summary is the caller's function name, prefixed with its class name
    when the caller's locals contain `self` or `cls`. Returns the new page, or
    None when the logger's level is above INFO or no calling frame is available.
    """
    if self.level > self.levels["INFO"]:
        return None

    # Frame inspection has to happen right here: f_back of this frame is the caller.
    this_frame = inspect.currentframe()
    if this_frame is None:
        return None
    caller_frame = this_frame.f_back
    if caller_frame is None:
        return None

    # A 'self' or 'cls' local marks the caller as a method of some class.
    caller_locals = caller_frame.f_locals
    if "self" in caller_locals:
        owner_name = type(caller_locals["self"]).__name__
    elif "cls" in caller_locals:
        owner_name = caller_locals["cls"].__name__
    else:
        owner_name = None

    function_name = caller_frame.f_code.co_name
    caller_name = function_name if owner_name is None else owner_name + "." + function_name

    # The page stays unfinished until leave_function() closes it.
    function_page = self._add_page(summary=caller_name, show_in_call_tree=True, finished=False)
    self.page_stack.push(function_page)
    self.page_stack.write_stack_to_page(function_page)

    function_page.add_lines("\nENTER {}".format(caller_name), flush=True)
    return function_page
|
||||
|
||||
def leave_function(self) -> None:
    """
    Closes the page on top of the page stack and pops it.

    The page is marked finished and a "LEAVE <summary>" line is written to it.
    Nothing happens when the logger's level is above INFO or the stack is empty.
    """
    if self.level > self.levels["INFO"]:
        return None
    current_page = self.page_stack.top()
    if current_page is None:
        return None
    current_page.finished = True
    current_page.add_lines("\nLEAVE {}".format(current_page.summary), flush=True)
    self.page_stack.pop()
|
||||
|
||||
|
||||
class Page:
    """
    One HTML page of the logger's output.

    Args:
        page_logger: The PageLogger object that created this page.
        index: The index of the page; also used as the HTML file name.
        summary: A brief summary of the page's contents for display.
        indent_level: The level of indentation in the call tree.
        show_in_call_tree: Whether to display the page in the call tree.
        finished: Whether the page is complete.
    """

    def __init__(
        self,
        page_logger: PageLogger,
        index: int,
        summary: str,
        indent_level: int,
        show_in_call_tree: bool = True,
        finished: bool = True,
    ):
        """
        Stores the page's settings, derives its title and link text, and writes the (still empty) page to disk.
        """
        self.page_logger = page_logger
        self.index_str = str(index)
        self.summary = summary
        self.indent_level = indent_level
        self.show_in_call_tree = show_in_call_tree
        self.finished = finished

        # Derived display strings: page title, call-tree indentation, and the link to this page.
        self.file_title = self.index_str + " " + self.summary
        self.indentation_text = "| " * self.indent_level
        self.full_link = f'<a href="{self.index_str}.html">{self.file_title}</a>'
        self.line_text = self.indentation_text + self.full_link

        self.lines: List[str] = []
        self.flush()

    def add_lines(self, lines: str, flush: bool = False) -> None:
        """
        Appends `lines` to the page, one entry per newline-separated line.

        When `flush` is true the page is rewritten to disk afterwards.
        """
        # Splitting a string without any newline yields that string as the only entry.
        self.lines.extend(lines.split("\n"))
        if flush:
            self.flush()

    def flush(self) -> None:
        """
        Rewrites the page's HTML file in the logger's log directory.
        """
        target_path = os.path.join(self.page_logger.log_dir, self.index_str + ".html")
        with open(target_path, "w") as out:
            out.write(_html_opening(self.file_title, finished=self.finished))
            out.write(f"<h3>{self.file_title}</h3>\n")
            for text_line in self.lines:
                try:
                    out.write(f"{text_line}\n")
                except UnicodeEncodeError:
                    # A line the file's encoding cannot represent is replaced by a marker.
                    out.write("UnicodeEncodeError in this line.\n")
            out.write(_html_closing())
            out.flush()
|
||||
|
||||
|
||||
class PageStack:
    """
    The stack of pages for the function calls that are currently active, outermost call first.
    """

    def __init__(self) -> None:
        self.stack: List[Page] = []

    def push(self, page: Page) -> None:
        """Places `page` on top of the stack."""
        self.stack.append(page)

    def pop(self) -> Page:
        """Takes the top page off the stack and returns it."""
        return self.stack.pop()

    def size(self) -> int:
        """Reports how many pages are on the stack."""
        return len(self.stack)

    def top(self) -> Page | None:
        """Returns the top page while leaving it on the stack, or None if the stack is empty."""
        return self.stack[-1] if self.size() != 0 else None

    def write_stack_to_page(self, page: Page) -> None:
        """Writes the current call stack (one link line per active page) to `page` and flushes it."""
        page.add_lines("\nCALL STACK")
        for active_page in self.stack:
            page.add_lines(active_page.line_text)
        # Two empty lines separate the call stack from whatever is logged next.
        for _ in range(2):
            page.add_lines("")
        page.flush()
|
||||
@@ -0,0 +1,133 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentdhal_core import CancellationToken, Image
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentdhal_extensions.experimental.task_centric_memory import MemoryController
|
||||
|
||||
|
||||
class Teachability(Memory):
    """
    Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.

    Steps for usage:

        1. Instantiate MemoryController.
        2. Instantiate Teachability, passing the memory controller as a parameter.
        3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
        4. Use the AssistantAgent as usual, such as for chatting with the user.

    Args:
        memory_controller: The controller that stores and retrieves memos; its logger is reused here.
        name: Optional identifier for this memory instance. Defaults to "teachability".
    """

    def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:
        """Initialize Teachability."""
        self._memory_controller = memory_controller
        self._logger = memory_controller.logger
        self._name = name or "teachability"

    @property
    def name(self) -> str:
        """Get the memory instance identifier."""
        return self._name

    def _extract_text(self, content_item: str | MemoryContent) -> str:
        """
        Extract searchable text from content.

        Raises:
            ValueError: If the content is JSON but not a dict, is an image, or has an unsupported mime type.
        """
        if isinstance(content_item, str):
            return content_item

        content = content_item.content
        mime_type = content_item.mime_type

        if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
            return str(content)
        elif mime_type == MemoryMimeType.JSON:
            if isinstance(content, dict):
                # Lowercased str() of the dict (a Python repr, not a JSON serialization).
                return str(content).lower()
            raise ValueError("JSON content must be a dict")
        elif isinstance(content, Image):
            raise ValueError("Image content cannot be converted to text")
        else:
            raise ValueError(f"Unsupported content type: {mime_type}")

    async def update_context(
        self,
        model_context: ChatCompletionContext,
    ) -> UpdateContextResult:
        """
        Extracts any advice from the last user turn to be stored in memory,
        and adds any relevant memories to the model context.

        Returns an empty result when the context holds no messages.
        """
        self._logger.enter_function()
        # try/finally keeps the logger's page stack balanced even if a step below raises.
        try:
            # Extract text from the last message in the context.
            messages = await model_context.get_messages()
            if not messages:
                return UpdateContextResult(memories=MemoryQueryResult(results=[]))
            last_message = messages[-1]
            last_user_text = last_message.content if isinstance(last_message.content, str) else str(last_message)

            # Add any relevant memories to the chat history.
            query_results = await self.query(last_user_text)
            if query_results.results:
                memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
                memory_context = "\nPotentially relevant memories:\n" + "\n".join(memory_strings)
                await model_context.add_message(UserMessage(content=memory_context, source="user"))

            # Add any user advice to memory.
            await self._memory_controller.consider_memo_storage(last_user_text)

            return UpdateContextResult(memories=query_results)
        finally:
            self._logger.leave_function()

    async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
        """
        Tries to extract any advice from the passed content and add it to memory.

        `cancellation_token` is accepted for interface compatibility and is not used.
        """
        self._logger.enter_function()
        try:
            # Extract text from the incoming content.
            text = self._extract_text(content)

            # Check for advice to add to memory for later turns.
            await self._memory_controller.consider_memo_storage(text)
        finally:
            self._logger.leave_function()

    async def query(
        self,
        query: str | MemoryContent,
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> MemoryQueryResult:
        """
        Returns any memories that seem relevant to the query.

        Each retrieved memo's insight is returned as a plain-text MemoryContent.
        `cancellation_token` and `kwargs` are accepted for interface compatibility and are not used.
        """
        self._logger.enter_function()
        try:
            task = self._extract_text(query)
            filtered_memos = await self._memory_controller.retrieve_relevant_memos(task=task)

            memory_results: list[MemoryContent] = []
            for memo in filtered_memos:
                insight = memo.insight
                self._logger.info(f"Insight: {insight}")
                memory_results.append(
                    MemoryContent(
                        content=insight,
                        # The enum member, not the string "MemoryMimeType.TEXT", so that
                        # mime-type comparisons (as in _extract_text) recognise it as text.
                        mime_type=MemoryMimeType.TEXT,
                        metadata={},
                    )
                )

            return MemoryQueryResult(results=memory_results)
        finally:
            self._logger.leave_function()

    async def clear(self) -> None:
        """Clear all entries from memory."""
        self._memory_controller.reset_memory()

    async def close(self) -> None:
        """Clean up memory resources."""
        pass  # No cleanup needed for this memory implementation
|
||||
1
agent_dhal/agentdhal_extensions/memory/__init__.py
Normal file
1
agent_dhal/agentdhal_extensions/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._text_canvas import TextCanvas
|
||||
from ._text_canvas_memory import TextCanvasMemory
|
||||
|
||||
__all__ = ["TextCanvas", "TextCanvasMemory"]
|
||||
50
agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py
Normal file
50
agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
class BaseCanvas(ABC):
    """
    Abstract interface for "canvas" objects: stores of file-like data that keep
    a revision history per file. What kind of data a canvas holds (text, images,
    structured data, ...) is up to the concrete subclass.

    .. warning::

        This is an experimental API and may change in the future.

    """

    @abstractmethod
    def list_files(self) -> Dict[str, int]:
        """
        Maps every filename on the canvas to its latest revision number.
        """
        raise NotImplementedError

    @abstractmethod
    def get_latest_content(self, filename: str) -> Union[str, bytes, Any]:
        """
        Fetches the content of the newest revision of `filename`.
        """
        raise NotImplementedError

    @abstractmethod
    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """
        Stores `new_content` as a new revision of `filename`, creating the file if needed.
        """
        raise NotImplementedError

    @abstractmethod
    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """
        Produces a diff (format chosen by the subclass) between two revisions of `filename`.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """
        Applies a patch/diff to the latest revision of `filename`, producing a new revision.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,64 @@
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._text_canvas import TextCanvas
|
||||
|
||||
|
||||
class UpdateFileArgs(BaseModel):
    # Arguments for UpdateFileTool; its run() forwards both fields to the canvas's add_or_update_file().
    filename: str  # name of the canvas file to create or overwrite
    new_content: str  # text stored as the file's new content
|
||||
|
||||
|
||||
class UpdateFileResult(BaseModel):
    # Result returned by UpdateFileTool.
    status: str  # UpdateFileTool.run() sets this to "OK" after the canvas call returns
|
||||
|
||||
|
||||
class UpdateFileTool(BaseTool[UpdateFileArgs, UpdateFileResult]):
    """
    Tool that creates a file on the canvas or replaces its content with a new revision.
    """

    def __init__(self, canvas: TextCanvas):
        tool_settings = {
            "args_type": UpdateFileArgs,
            "return_type": UpdateFileResult,
            "name": "update_file",
            "description": "Create/update a file on the canvas with the provided content.",
        }
        super().__init__(**tool_settings)
        self._canvas = canvas

    async def run(self, args: UpdateFileArgs, cancellation_token: CancellationToken) -> UpdateFileResult:
        """Store `args.new_content` under `args.filename` on the canvas and report "OK"."""
        target, content = args.filename, args.new_content
        self._canvas.add_or_update_file(target, content)
        return UpdateFileResult(status="OK")
|
||||
|
||||
|
||||
class ApplyPatchArgs(BaseModel):
    # Arguments for ApplyPatchTool; its run() forwards both fields to the canvas's apply_patch().
    filename: str  # name of the canvas file the patch is applied to
    patch_text: str  # the patch itself; the tool description asks for unified diff format
|
||||
|
||||
|
||||
class ApplyPatchResult(BaseModel):
    # Result returned by ApplyPatchTool.
    status: str  # ApplyPatchTool.run() sets this to "PATCH APPLIED" after the canvas call returns
|
||||
|
||||
|
||||
class ApplyPatchTool(BaseTool[ApplyPatchArgs, ApplyPatchResult]):
    """
    Tool that applies a unified diff patch to a file that already exists on the canvas.
    """

    def __init__(self, canvas: TextCanvas):
        tool_description = (
            "Apply a unified diff patch to an existing file on the canvas. "
            "The patch must be in diff/patch format. The file must exist or be created first."
        )
        tool_settings = {
            "args_type": ApplyPatchArgs,
            "return_type": ApplyPatchResult,
            "name": "apply_patch",
            "description": tool_description,
        }
        super().__init__(**tool_settings)
        self._canvas = canvas

    async def run(self, args: ApplyPatchArgs, cancellation_token: CancellationToken) -> ApplyPatchResult:
        """Apply `args.patch_text` to `args.filename` on the canvas and report "PATCH APPLIED"."""
        target, patch = args.filename, args.patch_text
        self._canvas.apply_patch(target, patch)
        return ApplyPatchResult(status="PATCH APPLIED")
|
||||
192
agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py
Normal file
192
agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import difflib
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
try: # pragma: no cover
|
||||
from unidiff import PatchSet
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
PatchSet = None # type: ignore
|
||||
|
||||
from ._canvas import BaseCanvas
|
||||
|
||||
|
||||
class FileRevision:
    """One stored revision of a file: its content plus the revision identifier."""

    __slots__ = ("content", "revision")

    def __init__(self, content: str, revision: int) -> None:
        # `revision` is the identifier of this snapshot (e.g. an integer, a timestamp, or git hash).
        self.content, self.revision = content, revision
|
||||
|
||||
|
||||
class TextCanvas(BaseCanvas):
    """An in-memory canvas that stores *text* files with full revision history.

    .. warning::

        This is an experimental API and may change in the future.

    On top of the BaseCanvas operations this class offers:

    * **apply_patch** - parses a unified diff with the ``unidiff`` library and
      applies its hunks by line position. Context lines are not compared
      against the stored content.
    * **get_revision_content** - random access to any stored revision.
    * **get_revision_diffs** - the unified diffs between every consecutive pair
      of revisions, so a caller can replay or audit the change history.
    """

    # ----------------------------------------------------------------------------------
    # Construction
    # ----------------------------------------------------------------------------------

    def __init__(self) -> None:
        # filename -> revisions in the order they were added; the last entry is the newest.
        self._files: Dict[str, List[FileRevision]] = {}

    # ----------------------------------------------------------------------------------
    # Internal utilities
    # ----------------------------------------------------------------------------------

    def _latest_idx(self, filename: str) -> int:
        """Return the list index (not the revision number) of the newest revision; -1 if there is none."""
        stored = self._files.get(filename, [])
        return len(stored) - 1

    def _ensure_file(self, filename: str) -> None:
        """Raise ValueError unless *filename* exists on the canvas."""
        if filename in self._files:
            return
        raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.")

    # ----------------------------------------------------------------------------------
    # Revision inspection
    # ----------------------------------------------------------------------------------

    def get_revision_content(self, filename: str, revision: int) -> str:
        """Return the content stored for *revision* of *filename*.

        An unknown file or revision yields an empty string instead of an exception.
        """
        stored = self._files.get(filename, [])
        return next((rev.content for rev in stored if rev.revision == revision), "")

    def get_revision_diffs(self, filename: str) -> List[str]:
        """Return the unified diffs between consecutive revisions of *filename*, oldest first.

        Entry *k* is the diff that turns the *k*-th stored revision into the next
        one. A file with fewer than two revisions yields an empty list.
        """
        history = self._files.get(filename, [])
        diffs: List[str] = []
        for older, newer in zip(history, history[1:]):
            diff_lines = difflib.unified_diff(
                older.content.splitlines(keepends=True),
                newer.content.splitlines(keepends=True),
                fromfile=f"{filename}@r{older.revision}",
                tofile=f"{filename}@r{newer.revision}",
            )
            diffs.append("".join(diff_lines))
        return diffs

    # ----------------------------------------------------------------------------------
    # BaseCanvas interface
    # ----------------------------------------------------------------------------------

    def list_files(self) -> Dict[str, int]:
        """Return a mapping of *filename -> latest revision number*."""
        latest: Dict[str, int] = {}
        for fname, history in self._files.items():
            if history:
                latest[fname] = history[-1].revision
        return latest

    def get_latest_content(self, filename: str) -> str:  # noqa: D401 – keep API identical
        """Return the newest content of *filename*, or an empty string if it has no revisions."""
        history = self._files.get(filename, [])
        if not history:
            return ""
        return history[-1].content

    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """Create *filename* or append a new revision containing *new_content*.

        Bytes are decoded as UTF-8. Revision numbers start at 1 and grow by one
        per update. Raises ValueError for content that is neither str nor bytes.
        """
        if isinstance(new_content, bytes):
            new_content = new_content.decode("utf-8")
        if not isinstance(new_content, str):
            raise ValueError(f"Expected str or bytes, got {type(new_content)}")
        history = self._files.get(filename)
        if history is None:
            self._files[filename] = [FileRevision(new_content, 1)]
            return
        next_revision = history[-1].revision + 1
        history.append(FileRevision(new_content, next_revision))

    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """Return a unified diff between *from_revision* and *to_revision*.

        Returns an empty string when the file has no revisions or when both
        revision lookups come back empty.
        """
        if not self._files.get(filename, []):
            return ""
        old_text = self.get_revision_content(filename, from_revision)
        new_text = self.get_revision_content(filename, to_revision)
        # An empty lookup means "revision not found" or "stored content is empty";
        # only when both sides are empty is there nothing to diff.
        if old_text == "" and new_text == "":
            return ""
        diff_lines = difflib.unified_diff(
            old_text.splitlines(keepends=True),
            new_text.splitlines(keepends=True),
            fromfile=f"{filename}@r{from_revision}",
            tofile=f"{filename}@r{to_revision}",
        )
        return "".join(diff_lines)

    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """Apply *patch_data* (a unified diff) to the latest revision and save the result as a new revision.

        The patch is parsed with the *unidiff* library and only its first file
        entry is used. Each hunk replaces the source line range it names with
        the hunk's added and context lines.

        Raises:
            ValueError: If *patch_data* is neither str nor bytes, the file does not
                exist on the canvas, or the patch contains no file entries.
            ImportError: If the *unidiff* package is not installed.
        """
        if isinstance(patch_data, bytes):
            patch_data = patch_data.decode("utf-8")
        if not isinstance(patch_data, str):
            raise ValueError(f"Expected str or bytes, got {type(patch_data)}")
        self._ensure_file(filename)
        current_text = self.get_latest_content(filename)

        if PatchSet is None:
            raise ImportError(
                "The 'unidiff' package is required for patch application. Install with 'pip install unidiff'."
            )

        parsed_patch = PatchSet(patch_data)
        if not parsed_patch:
            raise ValueError("Empty patch text provided.")
        # One canvas file per patch operation: only the first patched file is used.
        file_patch = parsed_patch[0]

        result_lines = current_text.splitlines(keepends=True)
        # Net number of lines added so far; shifts later hunks' positions.
        shift = 0
        for hunk in file_patch:
            first = hunk.source_start - 1 + shift
            last = first + hunk.source_length
            # Added and context lines survive; removed lines are dropped.
            kept = [entry.value for entry in hunk if entry.is_added or entry.is_context]
            result_lines[first:last] = kept
            shift += len(kept) - hunk.source_length

        # Commit the patched text as the next revision.
        self.add_or_update_file(filename, "".join(result_lines))

    # ----------------------------------------------------------------------------------
    # Convenience helpers
    # ----------------------------------------------------------------------------------

    def get_all_contents_for_context(self) -> str:  # noqa: D401 – keep public API stable
        """Return one text block listing every file with the content of its *latest* revision."""
        sections: List[str] = ["=== CANVAS FILES ==="]
        for fname, history in self._files.items():
            newest = history[-1]
            sections.append(f"File: {fname} (rev {newest.revision}):\n{newest.content}\n")
        sections.append("=== END OF CANVAS ===")
        return "\n".join(sections)
|
||||
@@ -0,0 +1,229 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.memory import (
|
||||
Memory,
|
||||
MemoryContent,
|
||||
MemoryMimeType,
|
||||
MemoryQueryResult,
|
||||
UpdateContextResult,
|
||||
)
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
|
||||
from ._canvas_writer import ApplyPatchTool, UpdateFileTool
|
||||
from ._text_canvas import TextCanvas
|
||||
|
||||
|
||||
class TextCanvasMemory(Memory):
|
||||
"""
|
||||
A memory implementation that uses a Canvas for storing file-like content.
|
||||
Inserts the current state of the canvas into the ChatCompletionContext on each turn.
|
||||
|
||||
.. warning::
|
||||
|
||||
This is an experimental API and may change in the future.
|
||||
|
||||
The TextCanvasMemory provides a persistent, file-like storage mechanism that can be used
|
||||
by agents to read and write content. It automatically injects the current state of all files
|
||||
in the canvas into the model context before each inference.
|
||||
|
||||
This is particularly useful for:
|
||||
- Allowing agents to create and modify documents over multiple turns
|
||||
- Enabling collaborative document editing between multiple agents
|
||||
- Maintaining persistent state across conversation turns
|
||||
- Working with content too large to fit in a single message
|
||||
|
||||
The canvas provides tools for:
|
||||
- Creating or updating files with new content
|
||||
- Applying patches (unified diff format) to existing files
|
||||
|
||||
Examples:
|
||||
|
||||
**Example: Using TextCanvasMemory with an AssistantAgent**
|
||||
|
||||
The following example demonstrates how to create a TextCanvasMemory and use it with
|
||||
an AssistantAgent to write and update a story file.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_extensions.memory.canvas import TextCanvasMemory
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o",
|
||||
# api_key = "your_openai_api_key"
|
||||
)
|
||||
|
||||
# Create the canvas memory
|
||||
text_canvas_memory = TextCanvasMemory()
|
||||
|
||||
# Get tools for working with the canvas
|
||||
update_file_tool = text_canvas_memory.get_update_file_tool()
|
||||
apply_patch_tool = text_canvas_memory.get_apply_patch_tool()
|
||||
|
||||
# Create an agent with the canvas memory and tools
|
||||
writer_agent = AssistantAgent(
|
||||
name="Writer",
|
||||
model_client=model_client,
|
||||
description="A writer agent that creates and updates stories.",
|
||||
system_message='''
|
||||
You are a Writer Agent. Your focus is to generate a story based on the user's request.
|
||||
|
||||
Instructions for using the canvas:
|
||||
|
||||
- The story should be stored on the canvas in a file named "story.md".
|
||||
- If "story.md" does not exist, create it by calling the 'update_file' tool.
|
||||
- If "story.md" already exists, generate a unified diff (patch) from the current
|
||||
content to the new version, and call the 'apply_patch' tool to apply the changes.
|
||||
|
||||
IMPORTANT: Do not include the full story text in your chat messages.
|
||||
Only write the story content to the canvas using the tools.
|
||||
''',
|
||||
tools=[update_file_tool, apply_patch_tool],
|
||||
memory=[text_canvas_memory],
|
||||
)
|
||||
|
||||
# Send a message to the agent
|
||||
await writer_agent.on_messages(
|
||||
[TextMessage(content="Write a short story about a bunny and a sunflower.", source="user")],
|
||||
CancellationToken(),
|
||||
)
|
||||
|
||||
# Retrieve the content from the canvas
|
||||
story_content = text_canvas_memory.canvas.get_latest_content("story.md")
|
||||
print("Story content from canvas:")
|
||||
print(story_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
**Example: Using TextCanvasMemory with multiple agents**
|
||||
|
||||
The following example shows how to use TextCanvasMemory with multiple agents
|
||||
collaborating on the same document.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_extensions.memory.canvas import TextCanvasMemory
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o",
|
||||
# api_key = "your_openai_api_key"
|
||||
)
|
||||
|
||||
# Create the shared canvas memory
|
||||
text_canvas_memory = TextCanvasMemory()
|
||||
update_file_tool = text_canvas_memory.get_update_file_tool()
|
||||
apply_patch_tool = text_canvas_memory.get_apply_patch_tool()
|
||||
|
||||
# Create a writer agent
|
||||
writer_agent = AssistantAgent(
|
||||
name="Writer",
|
||||
model_client=model_client,
|
||||
description="A writer agent that creates stories.",
|
||||
system_message="You write children's stories on the canvas in story.md.",
|
||||
tools=[update_file_tool, apply_patch_tool],
|
||||
memory=[text_canvas_memory],
|
||||
)
|
||||
|
||||
# Create a critique agent
|
||||
critique_agent = AssistantAgent(
|
||||
name="Critique",
|
||||
model_client=model_client,
|
||||
description="A critique agent that provides feedback on stories.",
|
||||
system_message="You review the story.md file and provide constructive feedback.",
|
||||
memory=[text_canvas_memory],
|
||||
)
|
||||
|
||||
# Create a team with both agents
|
||||
team = RoundRobinGroupChat(
|
||||
participants=[writer_agent, critique_agent],
|
||||
termination_condition=TextMentionTermination("TERMINATE"),
|
||||
max_turns=10,
|
||||
)
|
||||
|
||||
# Run the team on a task
|
||||
await team.run(task="Create a children's book about a bunny and a sunflower")
|
||||
|
||||
# Get the final story
|
||||
story = text_canvas_memory.canvas.get_latest_content("story.md")
|
||||
print(story)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(self, canvas: Optional[TextCanvas] = None):
    """Create the memory around an existing canvas, or around a new empty one.

    Args:
        canvas: Canvas that holds the files. When ``None``, a fresh
            ``TextCanvas`` is created for this memory.
    """
    super().__init__()
    if canvas is None:
        canvas = TextCanvas()
    self.canvas = canvas
|
||||
|
||||
async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
    """Add the current canvas snapshot to the model context.

    The text returned by ``self.canvas.get_all_contents_for_context()`` is
    appended to ``model_context`` as a single ``SystemMessage``. Nothing is
    added when that text is empty or whitespace-only.

    Args:
        model_context: Context that receives the snapshot message.

    Returns:
        An ``UpdateContextResult`` carrying the snapshot as one text
        ``MemoryContent``, or an empty result when nothing was injected.
    """
    canvas_text = self.canvas.get_all_contents_for_context()
    if not canvas_text.strip():
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    await model_context.add_message(SystemMessage(content=canvas_text))

    # Echo the injected text back so callers can log or inspect it.
    injected = MemoryContent(content=canvas_text, mime_type=MemoryMimeType.TEXT)
    return UpdateContextResult(memories=MemoryQueryResult(results=[injected]))
|
||||
|
||||
async def query(
    self, query: str | MemoryContent, cancellation_token: Optional[CancellationToken] = None, **kwargs: Any
) -> MemoryQueryResult:
    """Return an empty result; searching the canvas is not implemented.

    All arguments are accepted and ignored. A real implementation could
    match ``query`` against file names or file content.

    Returns:
        A ``MemoryQueryResult`` whose ``results`` list is always empty.
    """
    nothing_found = MemoryQueryResult(results=[])
    return nothing_found
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None:
    """Accept a memory item without storing it.

    Canvas changes are made through the tools returned by
    ``get_update_file_tool()`` and ``get_apply_patch_tool()``; this method
    exists to satisfy the memory interface and ignores both arguments.
    """
    # Intentionally a no-op: edits go through the canvas tools instead.
    return None
|
||||
|
||||
async def clear(self) -> None:
    """Discard all canvas content by swapping in a new, empty ``TextCanvas``.

    NOTE(review): tools returned earlier by ``get_update_file_tool()`` /
    ``get_apply_patch_tool()`` were constructed with the previous canvas
    object; presumably they keep operating on that old canvas after this
    call. Confirm against the tool classes and re-create the tools if so.
    """
    # Rebind to a fresh instance rather than re-running __init__ on the old one.
    self.canvas = TextCanvas()
|
||||
|
||||
async def close(self) -> None:
    """Do nothing; this memory holds no resources that need releasing."""
    return None
|
||||
|
||||
def get_update_file_tool(self) -> UpdateFileTool:
    """Build an ``UpdateFileTool`` bound to this memory's current canvas.

    A new tool object is created on every call.
    """
    bound_canvas = self.canvas
    return UpdateFileTool(bound_canvas)
|
||||
|
||||
def get_apply_patch_tool(self) -> ApplyPatchTool:
    """Build an ``ApplyPatchTool`` bound to this memory's current canvas.

    A new tool object is created on every call.
    """
    bound_canvas = self.canvas
    return ApplyPatchTool(bound_canvas)
|
||||
21
agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py
Normal file
21
agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
from ._chromadb import ChromaDBVectorMemory
|
||||
|
||||
__all__ = [
|
||||
"ChromaDBVectorMemory",
|
||||
"ChromaDBVectorMemoryConfig",
|
||||
"PersistentChromaDBVectorMemoryConfig",
|
||||
"HttpChromaDBVectorMemoryConfig",
|
||||
"DefaultEmbeddingFunctionConfig",
|
||||
"SentenceTransformerEmbeddingFunctionConfig",
|
||||
"OpenAIEmbeddingFunctionConfig",
|
||||
"CustomEmbeddingFunctionConfig",
|
||||
]
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Configuration classes for ChromaDB vector memory."""
|
||||
|
||||
from typing import Any, Callable, Dict, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Select ChromaDB's built-in default embedding function.

    The default is documented here as Sentence Transformers
    ``all-MiniLM-L6-v2``. This variant takes no further options.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator value identifying this variant in ``EmbeddingFunctionConfig``.
    function_type: Literal["default"] = "default"
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Select a SentenceTransformer model as the embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig

            _ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator value identifying this variant in ``EmbeddingFunctionConfig``.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="SentenceTransformer model name to use",
    )
|
||||
|
||||
|
||||
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Select OpenAI's embedding API as the embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. Defaults to an empty string, which is
            meant for setups where the key comes from an environment variable
            instead (TODO confirm the fallback for the installed chromadb version).
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig

            _ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator value identifying this variant in ``EmbeddingFunctionConfig``.
    function_type: Literal["openai"] = "openai"
    api_key: str = Field(
        default="",
        description="OpenAI API key",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="OpenAI embedding model name",
    )
|
||||
|
||||
|
||||
class CustomEmbeddingFunctionConfig(BaseModel):
    """Select a user-supplied factory that builds the embedding function.

    The factory is expected to return a ChromaDB-compatible embedding
    function; ``params`` are the keyword arguments intended for it.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::
       Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.
    """

    # Discriminator value identifying this variant in ``EmbeddingFunctionConfig``.
    function_type: Literal["custom"] = "custom"
    # Required: there is no default factory.
    function: Callable[..., Any] = Field(
        description="Function that returns an embedding function",
    )
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Parameters to pass to the function",
    )
|
||||
|
||||
|
||||
# Tagged union of the embedding function configurations. Pydantic picks the
# concrete model from each payload's ``function_type`` value, which is the
# discriminator declared in the ``Field`` below.
EmbeddingFunctionConfig = Annotated[
    Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
        CustomEmbeddingFunctionConfig,
    ],
    Field(discriminator="function_type"),
]
|
||||
|
||||
|
||||
class ChromaDBVectorMemoryConfig(BaseModel):
    """Settings shared by every ChromaDB-backed memory, regardless of client type.

    ``client_type`` has no default here; ``PersistentChromaDBVectorMemoryConfig``
    and ``HttpChromaDBVectorMemoryConfig`` supply one and add the connection
    settings for their client.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.
    """

    # Which kind of ChromaDB client the memory should create.
    client_type: Literal["persistent", "http"]
    collection_name: str = Field(
        default="memory_store",
        description="Name of the ChromaDB collection",
    )
    distance_metric: str = Field(
        default="cosine",
        description="Distance metric for similarity search",
    )
    k: int = Field(
        default=3,
        description="Number of results to return in queries",
    )
    # None disables score filtering.
    score_threshold: float | None = Field(
        default=None,
        description="Minimum similarity score threshold",
    )
    allow_reset: bool = Field(
        default=False,
        description="Whether to allow resetting the ChromaDB client",
    )
    tenant: str = Field(
        default="default_tenant",
        description="Tenant to use",
    )
    database: str = Field(
        default="default_database",
        description="Database to use",
    )
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function",
    )
|
||||
|
||||
|
||||
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Settings for a ChromaDB memory that stores its data on the local filesystem."""

    # Defaults to "persistent"; the annotation still admits "http", as in the base class.
    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(
        default="./chroma_db",
        description="Path for persistent storage",
    )
|
||||
|
||||
|
||||
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Settings for a ChromaDB memory that talks to a remote server over HTTP(S)."""

    # Defaults to "http"; the annotation still admits "persistent", as in the base class.
    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(
        default="localhost",
        description="Host of the remote server",
    )
    port: int = Field(
        default=8000,
        description="Port of the remote server",
    )
    ssl: bool = Field(
        default=False,
        description="Whether to use HTTPS",
    )
    # None means no extra headers are configured.
    headers: Dict[str, str] | None = Field(
        default=None,
        description="Headers to send to the server",
    )
|
||||
459
agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py
Normal file
459
agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, Image
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import Document, Metadata
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Import the client protocol type and turn a missing dependency into an
# actionable installation hint (the original error stays chained as the cause).
# NOTE(review): the unguarded `from chromadb import HttpClient, PersistentClient`
# earlier in this module runs first, so a missing chromadb install presumably
# fails there before this message is ever shown - consider moving it in here.
try:
    from chromadb.api import ClientAPI
except ImportError as e:
    raise ImportError(
        "To use the ChromaDBVectorMemory the chromadb extra must be installed. "
        'Run `pip install "agentdhal-ext[chromadb]"`'
    ) from e
|
||||
|
||||
|
||||
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
"""
|
||||
Store and retrieve memory using vector similarity search powered by ChromaDB.
|
||||
|
||||
`ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for
|
||||
storing and retrieving content based on semantic similarity. It enhances agents with the ability
|
||||
to recall contextually relevant information during conversations by leveraging vector embeddings
|
||||
to find similar content.
|
||||
|
||||
This implementation serves as a reference for more complex memory systems using vector embeddings.
|
||||
For advanced use cases requiring specialized formatting of retrieved content, users should extend
|
||||
this class and override the `update_context()` method.
|
||||
|
||||
This implementation requires the ChromaDB extra to be installed. Install with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[chromadb]"
|
||||
|
||||
Args:
|
||||
config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory.
|
||||
If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values.
|
||||
Two config types are supported:
|
||||
* PersistentChromaDBVectorMemoryConfig: For local storage
|
||||
* HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core.memory import MemoryContent, MemoryMimeType
|
||||
from agentdhal_extensions.memory.chromadb import (
|
||||
ChromaDBVectorMemory,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
)
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."
|
||||
|
||||
|
||||
def fahrenheit_to_celsius(fahrenheit: float) -> float:
|
||||
return (fahrenheit - 32) * 5.0 / 9.0
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Use default embedding function
|
||||
default_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="user_preferences",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
k=3, # Return top 3 results
|
||||
score_threshold=0.5, # Minimum similarity score
|
||||
)
|
||||
)
|
||||
|
||||
# Using a custom SentenceTransformer model
|
||||
custom_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="multilingual_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
|
||||
model_name="paraphrase-multilingual-mpnet-base-v2"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Using OpenAI embeddings
|
||||
openai_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="openai_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=OpenAIEmbeddingFunctionConfig(
|
||||
api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Add user preferences to memory
|
||||
await openai_memory.add(
|
||||
MemoryContent(
|
||||
content="The user prefers weather temperatures in Celsius",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "units"},
|
||||
)
|
||||
)
|
||||
|
||||
# Create assistant agent with ChromaDB memory
|
||||
assistant = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=OpenAIChatCompletionClient(
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tools=[
|
||||
get_weather,
|
||||
fahrenheit_to_celsius,
|
||||
],
|
||||
max_tool_iterations=10,
|
||||
memory=[openai_memory],
|
||||
)
|
||||
|
||||
# The memory will automatically retrieve relevant content during conversations
|
||||
await Console(assistant.run_stream(task="What's the temperature in New York?"))
|
||||
|
||||
# Remember to close the memory when finished
|
||||
await default_memory.close()
|
||||
await custom_memory.close()
|
||||
await openai_memory.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- TextMessage (user) ----------
|
||||
What's the temperature in New York?
|
||||
---------- MemoryQueryEvent (assistant) ----------
|
||||
[MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})]
|
||||
---------- ToolCallRequestEvent (assistant) ----------
|
||||
[FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')]
|
||||
---------- ToolCallExecutionEvent (assistant) ----------
|
||||
[FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)]
|
||||
---------- ToolCallRequestEvent (assistant) ----------
|
||||
[FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')]
|
||||
---------- ToolCallExecutionEvent (assistant) ----------
|
||||
[FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)]
|
||||
---------- TextMessage (assistant) ----------
|
||||
The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = ChromaDBVectorMemoryConfig
|
||||
component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory"
|
||||
|
||||
def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
    """Store the configuration; no ChromaDB client is created here.

    Args:
        config: Memory configuration. When omitted, a
            ``PersistentChromaDBVectorMemoryConfig`` with default values is used.
    """
    # Client and collection start unset; _ensure_initialized() fills them in.
    self._client: ClientAPI | None = None
    self._collection: Collection | None = None
    self._config = config or PersistentChromaDBVectorMemoryConfig()
|
||||
|
||||
@property
def collection_name(self) -> str:
    """Name of the ChromaDB collection, as given in the configuration."""
    configured_name = self._config.collection_name
    return configured_name
|
||||
|
||||
def _create_embedding_function(self) -> Any:
    """Create the embedding function selected by ``embedding_function_config``.

    The concrete config class decides which function is built:

    * ``DefaultEmbeddingFunctionConfig``: ChromaDB's ``DefaultEmbeddingFunction``.
    * ``SentenceTransformerEmbeddingFunctionConfig``: a SentenceTransformer
      function for ``model_name``.
    * ``OpenAIEmbeddingFunctionConfig``: an OpenAI function for ``model_name``.
      An empty ``api_key`` falls back to the ``OPENAI_API_KEY`` environment
      variable, as the config documentation promises.
    * ``CustomEmbeddingFunctionConfig``: the result of ``function(**params)``.

    Returns:
        A ChromaDB-compatible embedding function.

    Raises:
        ValueError: If the embedding function type is unsupported, or the
            custom factory raises.
        ImportError: If required dependencies are not installed. Any failure
            while constructing the SentenceTransformer or OpenAI function is
            also reported as ``ImportError``, with the original error chained.
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as e:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from e

    config = self._config.embedding_function_config

    if isinstance(config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
        try:
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=config.model_name)
        except Exception as e:
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
            ) from e

    if isinstance(config, OpenAIEmbeddingFunctionConfig):
        import os

        # An empty key used to be passed through verbatim, so the documented
        # environment fallback never happened. Resolve it here; if the variable
        # is unset too, the empty string is passed on exactly as before.
        api_key = config.api_key or os.environ.get("OPENAI_API_KEY", "")
        try:
            return embedding_functions.OpenAIEmbeddingFunction(api_key=api_key, model_name=config.model_name)
        except Exception as e:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {e}"
            ) from e

    if isinstance(config, CustomEmbeddingFunctionConfig):
        try:
            return config.function(**config.params)
        except Exception as e:
            raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e

    raise ValueError(f"Unsupported embedding function config type: {type(config)}")
|
||||
|
||||
def _ensure_initialized(self) -> None:
    """Create the ChromaDB client and collection if they do not exist yet.

    The client type follows the concrete config class: a ``PersistentClient``
    for ``PersistentChromaDBVectorMemoryConfig``, an ``HttpClient`` for
    ``HttpChromaDBVectorMemoryConfig``. The collection is then fetched or
    created with the configured embedding function.

    Raises:
        ValueError: If the config is neither of the two supported classes.
        Exception: Any error raised while creating the client or collection
            is logged and re-raised unchanged.
    """
    if self._client is None:
        try:
            from chromadb.config import Settings

            cfg = self._config
            # Keyword arguments common to both client constructors.
            shared_kwargs = {
                "settings": Settings(allow_reset=cfg.allow_reset),
                "tenant": cfg.tenant,
                "database": cfg.database,
            }

            if isinstance(cfg, PersistentChromaDBVectorMemoryConfig):
                self._client = PersistentClient(path=cfg.persistence_path, **shared_kwargs)
            elif isinstance(cfg, HttpChromaDBVectorMemoryConfig):
                self._client = HttpClient(
                    host=cfg.host,
                    port=cfg.port,
                    ssl=cfg.ssl,
                    headers=cfg.headers,
                    **shared_kwargs,
                )
            else:
                raise ValueError(f"Unsupported config type: {type(self._config)}")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise

    if self._collection is not None:
        return

    try:
        embedder = self._create_embedding_function()

        # NOTE(review): the metric is stored under the plain metadata key
        # "distance_metric"; ChromaDB documents "hnsw:space" as the key that
        # selects the index distance function. Confirm whether the configured
        # metric actually takes effect before relying on it.
        self._collection = self._client.get_or_create_collection(
            name=self._config.collection_name,
            metadata={"distance_metric": self._config.distance_metric},
            embedding_function=embedder,
        )
    except Exception as e:
        logger.error(f"Failed to get/create collection: {e}")
        raise
|
||||
|
||||
def _extract_text(self, content_item: str | MemoryContent) -> str:
    """Turn a query or memory item into the text that gets embedded.

    Args:
        content_item: A plain string (returned unchanged) or a
            ``MemoryContent`` whose ``mime_type`` decides the conversion.

    Returns:
        ``str(content)`` for TEXT and MARKDOWN items; for JSON items, the
        lower-cased ``str()`` of the dict.

    Raises:
        ValueError: For JSON content that is not a dict, for ``Image``
            content, and for any other mime type.
    """
    if isinstance(content_item, str):
        return content_item

    payload = content_item.content
    kind = content_item.mime_type

    if kind in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
        return str(payload)

    if kind == MemoryMimeType.JSON:
        if not isinstance(payload, dict):
            raise ValueError("JSON content must be a dict")
        # This is the Python str() of the dict, lower-cased - not json.dumps output.
        return str(payload).lower()

    if isinstance(payload, Image):
        raise ValueError("Image content cannot be converted to text")

    raise ValueError(f"Unsupported content type: {kind}")
|
||||
|
||||
def _calculate_score(self, distance: float) -> float:
    """Convert a ChromaDB distance into a similarity score (larger means closer).

    With ``distance_metric == "cosine"`` the score is ``1 - distance / 2``
    (presumably assuming cosine distances in [0, 2] - TODO confirm). Every
    other metric uses ``1 / (1 + distance)``.
    """
    if self._config.distance_metric != "cosine":
        return 1.0 / (1.0 + distance)
    return 1.0 - (distance / 2.0)
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Query memory with the latest message and add the hits to the context.

    The last message in ``model_context`` is the query: its ``content`` when
    that is a string, otherwise ``str()`` of the whole message. Any results
    are appended to the context as one numbered ``SystemMessage``.

    Args:
        model_context: Context to read the query from and extend with results.

    Returns:
        An ``UpdateContextResult`` holding the query results. It is empty
        when the context has no messages.
    """
    history = await model_context.get_messages()
    if not history:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    newest = history[-1]
    if isinstance(newest.content, str):
        query_text = newest.content
    else:
        query_text = str(newest)

    found = await self.query(query_text)

    if found.results:
        # One numbered line per memory, starting at 1.
        numbered = "\n".join(
            f"{position}. {str(memory.content)}" for position, memory in enumerate(found.results, 1)
        )
        memory_context = "\nRelevant memory content:\n" + numbered
        await model_context.add_message(SystemMessage(content=memory_context))

    return UpdateContextResult(memories=found)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Store one memory item in the ChromaDB collection.

    The item's text (see ``_extract_text``) is added as a document under a
    freshly generated UUID. Its metadata is stored alongside it, extended
    with a ``mime_type`` entry holding ``str(content.mime_type)``.

    Args:
        content: The memory item to store.
        cancellation_token: Accepted for interface compatibility; not used.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error from text extraction or from ChromaDB is logged
            and re-raised unchanged.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        text = self._extract_text(content)

        # Work on a copy: writing "mime_type" into content.metadata itself
        # would silently modify the caller's MemoryContent.
        metadata_dict = dict(content.metadata or {})
        metadata_dict["mime_type"] = str(content.mime_type)

        self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())])

    except Exception as e:
        logger.error(f"Failed to add content to ChromaDB: {e}")
        raise
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Return the stored items most similar to ``query``.

    Up to ``k`` results are requested from ChromaDB. Each hit's distance is
    converted to a score with ``_calculate_score``, and hits scoring below
    ``score_threshold`` (when set) are dropped. Every returned item carries
    its stored metadata plus ``score`` and ``id`` entries.

    Args:
        query: Query text, or a ``MemoryContent`` converted with ``_extract_text``.
        cancellation_token: Accepted for interface compatibility; not used.
        **kwargs: Forwarded unchanged to the collection's ``query`` call.

    Returns:
        A ``MemoryQueryResult`` with the matching items. It is empty when
        ChromaDB returns no documents, metadatas or distances.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error from ChromaDB or result conversion is logged
            and re-raised unchanged.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        query_text = self._extract_text(query)

        results = self._collection.query(
            query_texts=[query_text],
            n_results=self._config.k,
            include=["documents", "metadatas", "distances"],
            **kwargs,
        )

        memory_results: List[MemoryContent] = []

        if (
            not results
            or not results.get("documents")
            or not results.get("metadatas")
            or not results.get("distances")
        ):
            return MemoryQueryResult(results=memory_results)

        # One query text was sent, so index 0 holds its result lists.
        documents: List[Document] = results["documents"][0] if results["documents"] else []
        metadatas: List[Metadata] = results["metadatas"][0] if results["metadatas"] else []
        distances: List[float] = results["distances"][0] if results["distances"] else []
        ids: List[str] = results["ids"][0] if results["ids"] else []

        for doc, metadata_dict, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False):
            score = self._calculate_score(distance)
            if self._config.score_threshold is not None and score < self._config.score_threshold:
                continue

            # Records written outside this class may have no metadata, in which
            # case the entry is presumably None; treat that as an empty mapping
            # instead of failing in dict() / .get().
            stored_metadata = metadata_dict or {}
            metadata = dict(stored_metadata)
            metadata["score"] = score
            metadata["id"] = doc_id

            # Fall back to plain text when no mime type was stored.
            mime_type = str(stored_metadata.get("mime_type", MemoryMimeType.TEXT.value))

            memory_results.append(
                MemoryContent(
                    content=doc,
                    mime_type=mime_type,
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=memory_results)

    except Exception as e:
        logger.error(f"Failed to query ChromaDB: {e}")
        raise
|
||||
|
||||
async def clear(self) -> None:
    """Delete every record in the collection; the collection itself is kept.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error from ChromaDB is logged and re-raised unchanged.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        existing = self._collection.get()
        # Nothing stored means nothing to delete.
        if not existing or not existing["ids"]:
            return
        self._collection.delete(ids=existing["ids"])
    except Exception as e:
        logger.error(f"Failed to clear ChromaDB collection: {e}")
        raise
|
||||
|
||||
async def close(self) -> None:
    """Drop the references to the collection and the client.

    Both attributes are set back to ``None``, so the next call that runs
    ``_ensure_initialized()`` creates them again.
    """
    self._collection, self._client = None, None
|
||||
|
||||
async def reset(self) -> None:
    """Reset the underlying ChromaDB client when the config permits it.

    Raises:
        RuntimeError: If ``allow_reset`` is not enabled in the configuration.
    """
    self._ensure_initialized()
    if not self._config.allow_reset:
        raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.")

    client = self._client
    if client is None:
        return

    try:
        client.reset()
    except Exception as e:
        # Best effort: a failed reset is logged and not propagated.
        logger.error(f"Error during ChromaDB reset: {e}")
    finally:
        # The cached collection handle is dropped whether or not reset succeeded.
        self._collection = None
|
||||
|
||||
def _to_config(self) -> ChromaDBVectorMemoryConfig:
    """Return the configuration object held by this memory instance."""
    current_config = self._config
    return current_config
|
||||
|
||||
@classmethod
def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
    """Build a new memory instance from ``config``."""
    instance = cls(config=config)
    return instance
|
||||
6
agent_dhal/agentdhal_extensions/memory/mem0/__init__.py
Normal file
6
agent_dhal/agentdhal_extensions/memory/mem0/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from ._mem0 import Mem0Memory, Mem0MemoryConfig
|
||||
|
||||
__all__ = [
|
||||
"Mem0Memory",
|
||||
"Mem0MemoryConfig",
|
||||
]
|
||||
449
agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py
Normal file
449
agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, TypedDict, cast
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentBase
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from mem0 import Memory as Memory0
|
||||
from mem0 import MemoryClient
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
class Mem0MemoryConfig(BaseModel):
    """Declarative settings from which a ``Mem0Memory`` component is built."""

    # Identity under which memories are stored and searched.
    user_id: Optional[str] = Field(
        description="User ID for memory operations. If not provided, a UUID will be generated.",
        default=None,
    )
    # Default cap on the number of hits returned by a query.
    limit: int = Field(
        description="Maximum number of results to return in memory queries.",
        default=10,
    )
    # Selects between the hosted client and the local client.
    is_cloud: bool = Field(
        description="Whether to use cloud Mem0 client (True) or local client (False).",
        default=True,
    )
    # Credential handed to the cloud client.
    api_key: Optional[str] = Field(
        description="API key for cloud Mem0 client. Required if is_cloud=True.",
        default=None,
    )
    # Settings dictionary handed to the local client.
    config: Optional[Dict[str, Any]] = Field(
        description="Configuration dictionary for local Mem0 client. Required if is_cloud=False.",
        default=None,
    )
|
||||
|
||||
|
||||
class MemoryResult(TypedDict, total=False):
    """Shape of one search hit returned by the Mem0 client; every key is optional."""

    # Text of the stored memory.
    memory: str
    # Relevance score reported for the hit.
    score: float
    # Metadata stored alongside the memory.
    metadata: Dict[str, Any]
    # Creation timestamp as a string (parsed with datetime.fromisoformat by the consumer).
    created_at: str
    # Last-update timestamp as a string (parsed the same way).
    updated_at: str
    # Category labels attached to the memory.
    categories: List[str]
|
||||
|
||||
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]):
|
||||
"""Mem0 memory implementation for AutoGen.
|
||||
|
||||
This component integrates with Mem0.ai's memory system, providing an implementation
|
||||
of AutoGen's Memory interface. It supports both cloud and local backends through the
|
||||
mem0ai Python package.
|
||||
|
||||
To use this component, you need to have the `mem0` (for cloud-only) or `mem0-local` (for local)
|
||||
extra installed for the `agentdhal-ext` package:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mem0]" # For cloud-based Mem0
|
||||
pip install -U "agentdhal-ext[mem0-local]" # For local Mem0
|
||||
|
||||
The memory component can store and retrieve information that agents need to remember
|
||||
across conversations. It also provides context updating for language models with
|
||||
relevant memories.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.memory.mem0 import Mem0Memory
|
||||
from agentdhal_core.memory import MemoryContent
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a local Mem0Memory (no API key required)
|
||||
memory = Mem0Memory(
|
||||
is_cloud=False,
|
||||
config={"path": ":memory:"}, # Use in-memory storage for testing
|
||||
)
|
||||
print("Memory initialized successfully!")
|
||||
|
||||
# Add something to memory
|
||||
test_content = "User likes the color blue."
|
||||
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
|
||||
print(f"Added content: {test_content}")
|
||||
|
||||
# Retrieve memories with a search query
|
||||
results = await memory.query("What color does the user like?")
|
||||
print(f"Query results: {len(results.results)} found")
|
||||
|
||||
for i, result in enumerate(results.results):
|
||||
print(f"Result {i+1}: {result}")
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Memory initialized successfully!
|
||||
Added content: User likes the color blue.
|
||||
Query results: 1 found
|
||||
Result 1: content='User likes the color blue' mime_type='text/plain' metadata={'score': 0.6977155806281953, 'created_at': datetime.datetime(2025, 7, 6, 17, 25, 18, 754725, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200)))}
|
||||
|
||||
Using it with an :class:`~agentdhal_agentchat.agents.AssistantAgent`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.memory import MemoryContent
|
||||
from agentdhal_extensions.memory.mem0 import Mem0Memory
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
|
||||
# Create a Mem0 memory instance
|
||||
memory = Mem0Memory(
|
||||
user_id="user123",
|
||||
is_cloud=False,
|
||||
config={"path": ":memory:"}, # Use in-memory storage for testing
|
||||
)
|
||||
|
||||
# Add something to memory
|
||||
test_content = "User likes the color blue."
|
||||
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
|
||||
|
||||
# Create an assistant agent with Mem0 memory
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
memory=[memory],
|
||||
system_message="You are a helpful assistant that remembers user preferences.",
|
||||
)
|
||||
|
||||
# Run a sample task
|
||||
result = await agent.run(task="What color does the user like?")
|
||||
print(result.messages[-1].content) # type: ignore
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
User likes the color blue.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
|
||||
limit: Maximum number of results to return in memory queries.
|
||||
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
|
||||
api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided.
|
||||
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
component_provider_override = "agentdhal_extensions.memory.mem0.Mem0Memory"
|
||||
component_config_schema = Mem0MemoryConfig
|
||||
|
||||
def __init__(
    self,
    user_id: Optional[str] = None,
    limit: int = 10,
    is_cloud: bool = True,
    api_key: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> None:
    """Store the settings and create the matching Mem0 client.

    Args:
        user_id: User ID for memory operations; a random UUID string is used when falsy.
        limit: Default maximum number of results returned by queries.
        is_cloud: Use ``MemoryClient`` when true, ``Memory0.from_config`` otherwise.
        api_key: API key passed to the cloud client.
        config: Configuration dictionary for the local client.

    Raises:
        ValueError: If ``is_cloud`` is false and ``config`` is ``None``.
    """
    # A local client cannot be built without its configuration dictionary.
    if config is None and not is_cloud:
        raise ValueError("config is required when using local Mem0 client (is_cloud=False)")

    self._user_id = user_id or str(uuid.uuid4())
    self._limit = limit
    self._is_cloud = is_cloud
    self._api_key = api_key
    self._config = config

    if not self._is_cloud:
        assert self._config is not None
        local_settings = self._config
        self._client = Memory0.from_config(config_dict=local_settings)  # type: ignore
        return

    self._client = MemoryClient(api_key=self._api_key)
|
||||
|
||||
@property
def user_id(self) -> str:
    """User ID under which this instance stores and searches memories."""
    return self._user_id

@property
def limit(self) -> int:
    """Default maximum number of results returned by memory queries."""
    return self._limit

@property
def is_cloud(self) -> bool:
    """``True`` when this instance was configured to use the cloud Mem0 client."""
    return self._is_cloud

@property
def config(self) -> Optional[Dict[str, Any]]:
    """Configuration dictionary given for the Mem0 client, or ``None``."""
    return self._config
|
||||
|
||||
async def add(
    self,
    content: MemoryContent,
    cancellation_token: Optional[CancellationToken] = None,
) -> None:
    """Add content to memory.

    The content is converted to a single text message and sent to the Mem0
    client as a ``user`` message. A ``user_id`` key in the content metadata
    overrides this instance's user ID for the call and is not stored as
    metadata. The caller's metadata dictionary is never modified.

    Args:
        content: The memory content to add.
        cancellation_token: Optional token to cancel operation. If it reports
            cancellation, nothing is sent to the client.

    Raises:
        Exception: If there's an error adding content to mem0 memory. The
            error is logged and re-raised.
    """
    import json

    if hasattr(content, "content") and hasattr(content, "mime_type"):
        raw = content.content
        # mime_type may be a plain string or an enum member carrying the
        # string in ``.value``; compare on the string form in both cases.
        mime = getattr(content.mime_type, "value", content.mime_type)
        if mime == "application/json" and not isinstance(raw, str):
            # Dicts and other JSON-serializable objects are sent as JSON text.
            message = json.dumps(raw)
        else:
            message = str(raw)

        # Work on a copy: "user_id" is popped below and must not disappear
        # from the caller's MemoryContent.
        metadata: Dict[str, Any] = dict(content.metadata or {})
    else:
        # Content handed over directly (e.g. as a plain string).
        message = str(content)
        metadata = {}

    if cancellation_token is not None:
        # NOTE(review): upstream autogen-core tokens expose ``is_cancelled()``
        # rather than a ``cancelled`` attribute; both spellings are accepted
        # here - confirm against agentdhal_core.CancellationToken.
        is_cancelled = getattr(cancellation_token, "is_cancelled", None)
        if callable(is_cancelled):
            cancelled = bool(is_cancelled())
        else:
            cancelled = bool(getattr(cancellation_token, "cancelled", False))
        if cancelled:
            return

    try:
        user_id = metadata.pop("user_id", self._user_id)
        # A client whose class is named "Memory" is called without extra
        # options; any other client gets output_format="v1.1".
        kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"}
        # Suppress warning messages printed by the mem0 client.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            self._client.add([{"role": "user", "content": message}], user_id=user_id, metadata=metadata, **kwargs)  # type: ignore
    except Exception as e:
        # Log for diagnostics, then let the caller see the failure.
        logger.error(f"Error adding to mem0 memory: {str(e)}")
        raise
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent = "",
    cancellation_token: Optional[CancellationToken] = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Query memory for relevant content.

    Args:
        query: The query to search for, either as string or MemoryContent.
        cancellation_token: Optional token to cancel operation. If it reports
            cancellation, an empty result is returned without searching.
        **kwargs: Additional query parameters to pass to mem0. ``limit``
            overrides the configured result limit and ``user_id`` overrides
            this instance's user ID for the call.

    Returns:
        MemoryQueryResult containing search results. Any error during the
        search is logged and yields an empty result instead of raising.
    """
    if isinstance(query, str):
        query_text = query
    elif hasattr(query, "content"):
        query_text = str(query.content)
    else:
        query_text = str(query)

    if cancellation_token is not None:
        # NOTE(review): upstream autogen-core tokens expose ``is_cancelled()``
        # rather than a ``cancelled`` attribute; both spellings are accepted
        # here - confirm against agentdhal_core.CancellationToken.
        is_cancelled = getattr(cancellation_token, "is_cancelled", None)
        if callable(is_cancelled):
            cancelled = bool(is_cancelled())
        else:
            cancelled = bool(getattr(cancellation_token, "cancelled", False))
        if cancelled:
            return MemoryQueryResult(results=[])

    try:
        limit = kwargs.pop("limit", self._limit)
        # Popped so it cannot collide with the explicit user_id argument below.
        user_id = kwargs.pop("user_id", self._user_id)
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            results = self._client.search(  # type: ignore
                query_text,
                user_id=user_id,
                limit=limit,
                **kwargs,
            )

        # The client answers either with {"results": [...]} or with a bare list.
        if isinstance(results, dict) and "results" in results:
            result_list = cast(List[MemoryResult], results["results"])
        else:
            result_list = cast(List[MemoryResult], results)

        memory_contents: List[MemoryContent] = []
        for result in result_list:
            # Copy so the client's own response dictionaries stay untouched.
            metadata: Dict[str, Any] = dict(result.get("metadata") or {})

            if "score" in result:
                metadata["score"] = result["score"]

            for key in ("created_at", "updated_at"):
                stamp = result.get(key)
                if not stamp:
                    continue
                try:
                    metadata[key] = datetime.fromisoformat(stamp)
                except (ValueError, TypeError):
                    # An unparseable timestamp is omitted; the hit is kept.
                    pass

            categories = result.get("categories")
            if categories:
                metadata["categories"] = categories

            memory_contents.append(
                MemoryContent(
                    content=result.get("memory", ""),
                    mime_type="text/plain",  # Default to text/plain
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=memory_contents)

    except Exception as e:
        # Best effort: log the error and report "nothing found".
        logger.error(f"Error querying mem0 memory: {str(e)}")
        return MemoryQueryResult(results=[])
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Inject memories relevant to the latest message into the model context.

    The last message of the context is used as the search query. When the
    search finds anything, the hits are appended to the context as one system
    message containing a numbered list.

    Args:
        model_context: The model context to update.

    Returns:
        UpdateContextResult containing the memories that were found (empty
        when the context holds no messages).
    """
    messages = await model_context.get_messages()
    if not messages:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    last_message = messages[-1]
    # Non-string content falls back to the string form of the whole message.
    if isinstance(last_message.content, str):
        query_text = last_message.content
    else:
        query_text = str(last_message)

    query_results = await self.query(query_text, limit=self._limit)
    if not query_results.results:
        return UpdateContextResult(memories=query_results)

    numbered = []
    for position, memory in enumerate(query_results.results, 1):
        numbered.append(f"{position}. {str(memory.content)}")
    memory_context = "\nRelevant memories:\n" + "\n".join(numbered)

    await model_context.add_message(SystemMessage(content=memory_context))
    return UpdateContextResult(memories=query_results)
|
||||
|
||||
async def clear(self) -> None:
    """Delete all memories stored for this instance's user ID.

    Raises:
        Exception: Any error raised by the client is logged and re-raised.
    """
    try:
        owner = self._user_id
        self._client.delete_all(user_id=owner)  # type: ignore
    except Exception as e:
        logger.error(f"Error clearing mem0 memory: {str(e)}")
        raise
|
||||
|
||||
async def close(self) -> None:
    """Do nothing; this implementation performs no cleanup on close."""
    return None
|
||||
|
||||
@classmethod
def _from_config(cls, config: Mem0MemoryConfig) -> Self:
    """Create an instance from a configuration object.

    Args:
        config: Configuration for the Mem0Memory component.

    Returns:
        A new instance constructed with the five configured values.
    """
    settings = {
        "user_id": config.user_id,
        "limit": config.limit,
        "is_cloud": config.is_cloud,
        "api_key": config.api_key,
        "config": config.config,
    }
    return cls(**settings)
|
||||
|
||||
def _to_config(self) -> Mem0MemoryConfig:
    """Convert this instance to its configuration object.

    Returns:
        A Mem0MemoryConfig carrying this instance's current settings.
    """
    # NOTE(review): the API key is copied into the returned config as-is;
    # treat serialized configs as sensitive.
    settings = {
        "user_id": self._user_id,
        "limit": self._limit,
        "is_cloud": self._is_cloud,
        "api_key": self._api_key,
        "config": self._config,
    }
    return Mem0MemoryConfig(**settings)
|
||||
9
agent_dhal/agentdhal_extensions/memory/redis/__init__.py
Normal file
9
agent_dhal/agentdhal_extensions/memory/redis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from ._redis_memory import (
|
||||
RedisMemory,
|
||||
RedisMemoryConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RedisMemoryConfig",
|
||||
"RedisMemory",
|
||||
]
|
||||
325
agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py
Normal file
325
agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import logging
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from redis import Redis
|
||||
from redisvl.extensions.message_history import SemanticMessageHistory
|
||||
from redisvl.utils.utils import deserialize, serialize
|
||||
except ImportError as e:
|
||||
raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e
|
||||
|
||||
|
||||
class RedisMemoryConfig(BaseModel):
    """
    Configuration for Redis-based vector memory.

    This class defines the configuration options for using Redis as a vector memory store,
    supporting semantic memory. It allows customization of the Redis connection, index settings,
    similarity search parameters, and embedding model.
    """

    # Connection string handed to the Redis client.
    redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
    # Name under which the message history is created.
    index_name: str = Field(default="chat_history", description="Name of the Redis collection")
    prefix: str = Field(default="memory", description="prefix of the Redis collection")
    distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
    algorithm: Literal["flat", "hnsw"] = "flat"
    # Default for the ``top_k`` query argument.
    top_k: int = Field(default=10, description="Number of results to return in queries")
    datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
    # Default for the ``distance_threshold`` query argument. This is an upper
    # bound on vector distance, not a minimum similarity score.
    distance_threshold: float = Field(
        default=0.7,
        description="Maximum vector distance for a stored memory to be considered relevant to a query",
    )
    model_name: str | None = Field(
        default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
    )
|
||||
|
||||
|
||||
class RedisMemory(Memory, Component[RedisMemoryConfig]):
|
||||
"""
|
||||
Store and retrieve memory using vector similarity search powered by RedisVL.
|
||||
|
||||
`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
|
||||
retrieving content based on semantic similarity. It enhances agents with the ability to recall
|
||||
contextually relevant information during conversations by leveraging vector embeddings to find
|
||||
similar content.
|
||||
|
||||
This implementation requires the RedisVL extra to be installed. Install with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[redisvl]"
|
||||
|
||||
Additionally, you will need access to a Redis instance.
|
||||
To run a local instance of redis in docker:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker run -d --name redis -p 6379:6379 redis:8
|
||||
|
||||
To download and run Redis locally:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
|
||||
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
|
||||
sudo apt-get update > /dev/null 2>&1
|
||||
sudo apt-get install redis-server > /dev/null 2>&1
|
||||
redis-server --daemonize yes
|
||||
|
||||
Args:
|
||||
config (RedisMemoryConfig | None): Configuration for the Redis memory.
|
||||
If None, defaults to a RedisMemoryConfig with recommended settings.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from logging import WARNING, getLogger
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core.memory import MemoryContent, MemoryMimeType
|
||||
from agentdhal_extensions.memory.redis import RedisMemory, RedisMemoryConfig
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
logger = getLogger()
|
||||
logger.setLevel(WARNING)
|
||||
|
||||
|
||||
# Define tool to use
|
||||
async def get_weather(city: str, units: str = "imperial") -> str:
|
||||
if units == "imperial":
|
||||
return f"The weather in {city} is 73 °F and Sunny."
|
||||
elif units == "metric":
|
||||
return f"The weather in {city} is 23 °C and Sunny."
|
||||
else:
|
||||
return f"Sorry, I don't know the weather in {city}."
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize Redis memory
|
||||
redis_memory = RedisMemory(
|
||||
config=RedisMemoryConfig(
|
||||
redis_url="redis://localhost:6379",
|
||||
index_name="chat_history",
|
||||
prefix="memory",
|
||||
)
|
||||
)
|
||||
|
||||
# Add user preferences to memory
|
||||
await redis_memory.add(
|
||||
MemoryContent(
|
||||
content="The weather should be in metric units",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "units"},
|
||||
)
|
||||
)
|
||||
|
||||
await redis_memory.add(
|
||||
MemoryContent(
|
||||
content="Meal recipe must be vegan",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "dietary"},
|
||||
)
|
||||
)
|
||||
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Create assistant agent with Redis memory
|
||||
assistant_agent = AssistantAgent(
|
||||
name="assistant_agent",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
memory=[redis_memory],
|
||||
)
|
||||
|
||||
stream = assistant_agent.run_stream(task="What is the weather in New York?")
|
||||
await Console(stream)
|
||||
|
||||
await model_client.close()
|
||||
await redis_memory.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- TextMessage (user) ----------
|
||||
What is the weather in New York?
|
||||
---------- MemoryQueryEvent (assistant_agent) ----------
|
||||
[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata={'category': 'preferences', 'type': 'units'})]
|
||||
---------- ToolCallRequestEvent (assistant_agent) ----------
|
||||
[FunctionCall(id='call_tyCPvPPAV4SHWhtfpM6UMemr', arguments='{"city":"New York","units":"metric"}', name='get_weather')]
|
||||
---------- ToolCallExecutionEvent (assistant_agent) ----------
|
||||
[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_tyCPvPPAV4SHWhtfpM6UMemr', is_error=False)]
|
||||
---------- ToolCallSummaryMessage (assistant_agent) ----------
|
||||
The weather in New York is 23 °C and Sunny.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = RedisMemoryConfig
# Import path recorded for this component. RedisMemory is exported from the
# ``agentdhal_extensions.memory.redis`` package, so the path must name that
# package for the class to be importable from a stored component config.
component_provider_override = "agentdhal_extensions.memory.redis.RedisMemory"
|
||||
|
||||
def __init__(self, config: RedisMemoryConfig | None = None) -> None:
    """Connect to Redis and create the semantic message history used as storage.

    Args:
        config: Settings for this memory; a default ``RedisMemoryConfig()`` is
            used when omitted.
    """
    self.config = config or RedisMemoryConfig()
    # NOTE(review): only redis_url and index_name are used here; the other
    # config fields are not passed to SemanticMessageHistory at construction.
    self.message_history = SemanticMessageHistory(
        name=self.config.index_name,
        redis_client=Redis.from_url(url=self.config.redis_url),  # type: ignore[reportUknownMemberType]
    )
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """
    Update the model context with relevant memory content.

    The last message in the context (or an empty string when the context is
    empty) is used as the query for a semantic search. When the search returns
    memories, they are joined with blank lines and added to the context as a
    single system message. When nothing is found the context is left
    unchanged, so no empty system message is ever inserted.

    Args:
        model_context (ChatCompletionContext): The model context to update with relevant
            memories.

    Returns:
        UpdateContextResult: Object containing the memories that were used to update the
            context.
    """
    messages = await model_context.get_messages()
    last_message = str(messages[-1].content) if messages else ""

    query_results = await self.query(last_message)

    if query_results.results:
        stringified_messages = "\n\n".join(str(m.content) for m in query_results.results)
        await model_context.add_message(SystemMessage(content=stringified_messages))

    return UpdateContextResult(memories=query_results)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Add a memory content object to Redis.

    .. note::

        To perform semantic search over stored memories RedisMemory creates a vector embedding
        from the content field of a MemoryContent object. This content is assumed to be text,
        JSON, or Markdown, and is passed to the vector embedding model specified in
        RedisMemoryConfig.

    The mime type may be given as a ``MemoryMimeType`` member or as its string
    value (for example ``"text/plain"``). The mime type is stored in the
    entry's metadata under the reserved ``"mime_type"`` key; a user-supplied
    metadata entry with that key is overwritten so the entry can always be
    decoded on read.

    Args:
        content (MemoryContent): The memory content to store within Redis.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Raises:
        NotImplementedError: If the mime type is not text, JSON, or Markdown.
    """
    try:
        # Enum lookup by value accepts both members and their raw string values.
        mime = MemoryMimeType(content.mime_type)
    except ValueError:
        mime = None

    if mime == MemoryMimeType.TEXT:
        memory_content = content.content
        mime_type = "text/plain"
    elif mime == MemoryMimeType.JSON:
        memory_content = serialize(content.content)
        mime_type = "application/json"
    elif mime == MemoryMimeType.MARKDOWN:
        memory_content = content.content
        mime_type = "text/markdown"
    else:
        raise NotImplementedError(
            f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
        )

    # Copy the user's metadata, then set the reserved key last so it wins.
    metadata = dict(content.metadata) if content.metadata else {}
    metadata["mime_type"] = mime_type

    # The serialized metadata is carried in the message's tool_call_id field.
    self.message_history.add_message(
        {"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)}  # type: ignore[reportArgumentType]
    )
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Query memory content based on semantic vector similarity.

    .. note::

        RedisMemory.query() supports additional keyword arguments to improve query performance.
        top_k (int): The maximum number of relevant memories to include. Defaults to the
        configured ``top_k``.
        distance_threshold (float): The maximum distance in vector space to consider a memory
        semantically similar when performing cosine similarity search. Defaults to the
        configured ``distance_threshold``.

    Args:
        query (str | MemoryContent): query to perform vector similarity search with. If a
            string is passed, a vector embedding is created from it with the model specified
            in the RedisMemoryConfig. If a MemoryContent object is passed, the content field
            of this object is extracted and a vector embedding is created from it with the
            model specified in the RedisMemoryConfig. Its mime type may be a
            ``MemoryMimeType`` member or the matching string value.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Returns:
        MemoryQueryResult: Object containing memories relevant to the provided query.

    Raises:
        NotImplementedError: If a mime type other than text, JSON, or Markdown is encountered.
        TypeError: If ``query`` is neither a string nor a MemoryContent.
    """
    if isinstance(query, str):
        prompt = query
    elif isinstance(query, MemoryContent):
        try:
            # Enum lookup by value accepts both members and their raw string values.
            query_mime = MemoryMimeType(query.mime_type)
        except ValueError:
            query_mime = None

        if query_mime in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
            prompt = str(query.content)
        elif query_mime == MemoryMimeType.JSON:
            prompt = serialize(query.content)
        else:
            raise NotImplementedError(
                f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
            )
    else:
        raise TypeError("'query' must be either a string or MemoryContent")

    top_k = kwargs.pop("top_k", self.config.top_k)
    distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)

    results = self.message_history.get_relevant(
        prompt=prompt,  # type: ignore[reportArgumentType]
        top_k=top_k,
        distance_threshold=distance_threshold,
        raw=False,
    )

    memories: List[MemoryContent] = []
    for result in results:
        # add() stores the serialized metadata in the tool_call_id field.
        metadata = deserialize(result["tool_call_id"])  # type: ignore[reportArgumentType]
        mime_type = MemoryMimeType(metadata.pop("mime_type"))
        if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
            memory_content = result["content"]  # type: ignore[reportArgumentType]
        elif mime_type == MemoryMimeType.JSON:
            memory_content = deserialize(result["content"])  # type: ignore[reportArgumentType]
        else:
            raise NotImplementedError(
                f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
            )
        memory = MemoryContent(
            content=memory_content,  # type: ignore[reportArgumentType]
            mime_type=mime_type,
            metadata=metadata,
        )
        memories.append(memory)  # type: ignore[reportUknownMemberType]

    return MemoryQueryResult(results=memories)  # type: ignore[reportUknownMemberType]
|
||||
|
||||
async def clear(self) -> None:
    """Remove all stored entries by calling ``clear()`` on the message history."""
    history = self.message_history
    history.clear()
|
||||
|
||||
async def close(self) -> None:
    """Tear down the memory by calling ``delete()`` on the message history.

    NOTE(review): per the original description this also clears all stored
    entries and removes the index, so memories do not survive ``close()`` -
    confirm against the RedisVL ``delete()`` semantics.
    """
    history = self.message_history
    history.delete()
|
||||
0
agent_dhal/agentdhal_extensions/py.typed
Normal file
0
agent_dhal/agentdhal_extensions/py.typed
Normal file
16
agent_dhal/agentdhal_extensions/runtimes/grpc/__init__.py
Normal file
16
agent_dhal/agentdhal_extensions/runtimes/grpc/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Check for the optional dependency *before* importing the submodules: they
# import grpc themselves, so running this guard afterwards could never be
# reached when grpc is missing and the installation hint below was dead code.
try:
    import grpc  # type: ignore  # noqa: F401
except ImportError as e:
    raise ImportError(
        "To use the GRPC runtime the grpc extra must be installed. Run `pip install autogen-ext[grpc]`"
    ) from e

from ._worker_runtime import GrpcWorkerAgentRuntime  # noqa: E402
from ._worker_runtime_host import GrpcWorkerAgentRuntimeHost  # noqa: E402
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer  # noqa: E402

__all__ = [
    "GrpcWorkerAgentRuntime",
    "GrpcWorkerAgentRuntimeHost",
    "GrpcWorkerAgentRuntimeHostServicer",
]
|
||||
13
agent_dhal/agentdhal_extensions/runtimes/grpc/_constants.py
Normal file
13
agent_dhal/agentdhal_extensions/runtimes/grpc/_constants.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Message of the ImportError raised by the worker runtime module when
# ``grpc.aio`` cannot be imported.
GRPC_IMPORT_ERROR_STR = (
    "Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
)

# Keys of the CloudEvent attribute map filled in when a message is published.
DATA_CONTENT_TYPE_ATTR = "datacontenttype"  # serialization format of the payload
DATA_SCHEMA_ATTR = "dataschema"  # registered type name of the payload
AGENT_SENDER_TYPE_ATTR = "agagentsendertype"  # ``type`` part of the sender's AgentId
AGENT_SENDER_KEY_ATTR = "agagentsenderkey"  # ``key`` part of the sender's AgentId
MESSAGE_KIND_ATTR = "agmsgkind"  # holds one of the MESSAGE_KIND_VALUE_* strings

# Values stored under MESSAGE_KIND_ATTR.
MESSAGE_KIND_VALUE_PUBLISH = "publish"
MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"
MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"
MESSAGE_KIND_VALUE_RPC_ERROR = "error"
|
||||
@@ -0,0 +1,4 @@
|
||||
from typing import Any, Sequence, Tuple
|
||||
|
||||
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors.
# A sequence of (option name, option value) pairs handed to the gRPC channel as ``options``.
ChannelArgumentType = Sequence[Tuple[str, Any]]
|
||||
45
agent_dhal/agentdhal_extensions/runtimes/grpc/_utils.py
Normal file
45
agent_dhal/agentdhal_extensions/runtimes/grpc/_utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from agentdhal_core._subscription import Subscription
|
||||
from agentdhal_core._type_prefix_subscription import TypePrefixSubscription
|
||||
from agentdhal_core._type_subscription import TypeSubscription
|
||||
|
||||
from .protos import agent_worker_pb2
|
||||
|
||||
|
||||
def subscription_to_proto(subscription: Subscription) -> agent_worker_pb2.Subscription:
    """Convert a runtime subscription object into its protobuf message.

    Args:
        subscription: A ``TypeSubscription`` or ``TypePrefixSubscription``.

    Returns:
        An ``agent_worker_pb2.Subscription`` carrying the subscription id and
        the matching oneof member.

    Raises:
        ValueError: If the subscription is of any other kind.
    """
    if isinstance(subscription, TypeSubscription):
        type_payload = agent_worker_pb2.TypeSubscription(
            topic_type=subscription.topic_type,
            agent_type=subscription.agent_type,
        )
        return agent_worker_pb2.Subscription(id=subscription.id, typeSubscription=type_payload)

    if isinstance(subscription, TypePrefixSubscription):
        prefix_payload = agent_worker_pb2.TypePrefixSubscription(
            topic_type_prefix=subscription.topic_type_prefix,
            agent_type=subscription.agent_type,
        )
        return agent_worker_pb2.Subscription(id=subscription.id, typePrefixSubscription=prefix_payload)

    raise ValueError("Unsupported subscription type.")
|
||||
|
||||
|
||||
def subscription_from_proto(subscription: agent_worker_pb2.Subscription) -> Subscription:
|
||||
oneofcase = subscription.WhichOneof("subscription")
|
||||
match oneofcase:
|
||||
case "typeSubscription":
|
||||
type_subscription_msg: agent_worker_pb2.TypeSubscription = subscription.typeSubscription
|
||||
return TypeSubscription(
|
||||
topic_type=type_subscription_msg.topic_type,
|
||||
agent_type=type_subscription_msg.agent_type,
|
||||
id=subscription.id,
|
||||
)
|
||||
|
||||
case "typePrefixSubscription":
|
||||
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = subscription.typePrefixSubscription
|
||||
return TypePrefixSubscription(
|
||||
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
|
||||
agent_type=type_prefix_subscription_msg.agent_type,
|
||||
id=subscription.id,
|
||||
)
|
||||
case None:
|
||||
raise ValueError("Invalid subscription message.")
|
||||
856
agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime.py
Normal file
856
agent_dhal/agentdhal_extensions/runtimes/grpc/_worker_runtime.py
Normal file
@@ -0,0 +1,856 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import Future, Task
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from agentdhal_core import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
PROTOBUF_DATA_CONTENT_TYPE,
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
MessageHandlerContext,
|
||||
MessageSerializer,
|
||||
Subscription,
|
||||
TopicId,
|
||||
)
|
||||
from agentdhal_core._runtime_impl_helpers import SubscriptionManager, get_impl
|
||||
from agentdhal_core._serialization import (
|
||||
SerializationRegistry,
|
||||
)
|
||||
from agentdhal_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
|
||||
from google.protobuf import any_pb2
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import Self
|
||||
|
||||
from agentdhal_extensions.runtimes.grpc._utils import subscription_to_proto
|
||||
|
||||
from . import _constants
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._type_helpers import ChannelArgumentType
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
|
||||
|
||||
try:
|
||||
import grpc.aio
|
||||
except ImportError as e:
|
||||
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
|
||||
|
||||
logger = logging.getLogger("agentdhal_core")
|
||||
event_logger = logging.getLogger("agentdhal_core.events")
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
type_func_alias = type
|
||||
|
||||
|
||||
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
    """Expose an ``asyncio.Queue`` as an async iterator.

    Each step of the iteration waits for, and yields, the next queued item.
    The iterator never signals exhaustion on its own.
    """

    def __init__(self, queue: asyncio.Queue[Any]) -> None:
        self._queue = queue

    def __aiter__(self) -> AsyncIterator[Any]:
        # The object is its own iterator.
        return self

    async def __anext__(self) -> Any:
        next_item = await self._queue.get()
        return next_item
|
||||
|
||||
|
||||
class HostConnection:
    """Handle for the message channel between this worker and a runtime host.

    Outgoing messages are staged on a send queue that feeds the stream opened
    through the stub's ``OpenChannel`` call; a background task copies every
    message read from that stream onto a receive queue.
    """

    # Default channel options: a retry policy applied through the gRPC service config.
    DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [
        (
            "grpc.service_config",
            json.dumps(
                {
                    "methodConfig": [
                        {
                            "name": [{}],
                            "retryPolicy": {
                                "maxAttempts": 3,
                                "initialBackoff": "0.01s",
                                "maxBackoff": "5s",
                                "backoffMultiplier": 2,
                                "retryableStatusCodes": ["UNAVAILABLE"],
                            },
                        }
                    ],
                }
            ),
        )
    ]

    def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None:  # type: ignore
        self._channel = channel
        self._stub: AgentRpcAsyncStub = stub
        # Random identifier sent to the host as the "client-id" metadata entry.
        self._client_id = str(uuid.uuid4())
        self._connection_task: Task[None] | None = None
        self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
        self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()

    @property
    def stub(self) -> Any:
        """The RPC stub bound to this connection's channel."""
        return self._stub

    @property
    def metadata(self) -> Sequence[Tuple[str, str]]:
        """Call metadata identifying this client to the host."""
        return [("client-id", self._client_id)]

    @classmethod
    async def from_host_address(
        cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG
    ) -> Self:
        """Open an insecure channel to ``host_address`` and start reading from it.

        Args:
            host_address: Address of the host to connect to.
            extra_grpc_config: Channel options layered on top of
                ``DEFAULT_GRPC_CONFIG``; entries with the same key win.

        Returns:
            A connected instance whose background read task is running.
        """
        logger.info("Connecting to %s", host_address)
        # Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config
        options_by_key = dict(HostConnection.DEFAULT_GRPC_CONFIG)
        options_by_key.update(dict(extra_grpc_config))
        merged_options = list(options_by_key.items())

        channel = grpc.aio.insecure_channel(host_address, options=merged_options)
        stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel)  # type: ignore
        connection = cls(channel, stub)

        connection._connection_task = await connection._connect(
            stub, connection._send_queue, connection._recv_queue, connection._client_id
        )
        return connection

    async def close(self) -> None:
        """Close the channel and wait for the background read task to end.

        Raises:
            RuntimeError: If the connection was never opened.
        """
        read_task = self._connection_task
        if read_task is None:
            raise RuntimeError("Connection is not open.")
        await self._channel.close()
        await read_task

    @staticmethod
    async def _connect(
        stub: Any,  # AgentRpcAsyncStub
        send_queue: asyncio.Queue[agent_worker_pb2.Message],
        receive_queue: asyncio.Queue[agent_worker_pb2.Message],
        client_id: str,
    ) -> Task[None]:
        """Open the message stream and return the task that drains it.

        The stream is fed from ``send_queue``; everything read from it is put
        on ``receive_queue`` until the stream reports EOF.
        """
        from grpc.aio import StreamStreamCall

        # TODO: where do exceptions from reading the iterable go? How do we recover from those?
        outgoing = QueueAsyncIterable(send_queue)
        stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel(  # type: ignore
            outgoing, metadata=[("client-id", client_id)]
        )

        await stream.wait_for_connection()

        async def read_loop() -> None:
            reached_eof = False
            while not reached_eof:
                logger.info("Waiting for message from host")
                message = cast(agent_worker_pb2.Message, await stream.read())  # type: ignore
                if message == grpc.aio.EOF:  # type: ignore
                    logger.info("EOF")
                    reached_eof = True
                else:
                    logger.info(f"Received a message from host: {message}")
                    await receive_queue.put(message)
                    logger.info("Put message in receive queue")

        return asyncio.create_task(read_loop())

    async def send(self, message: agent_worker_pb2.Message) -> None:
        """Queue ``message`` for delivery to the host."""
        logger.info(f"Send message to host: {message}")
        await self._send_queue.put(message)
        logger.info("Put message in send queue")

    async def recv(self) -> agent_worker_pb2.Message:
        """Wait for and return the next message received from the host."""
        logger.info("Getting message from queue")
        received = await self._recv_queue.get()
        return received
|
||||
|
||||
|
||||
# TODO: Lots of types need to have protobuf equivalents:
|
||||
# Core:
|
||||
# - FunctionCall, CodeResult, possibly CodeBlock
|
||||
# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/agentdhal_core/models/_types.py
|
||||
#
|
||||
# Agentchat:
|
||||
# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/src/agentdhal_agentchat/messages.py to protobufs.
|
||||
#
|
||||
# Ext --
|
||||
# CodeExecutor:
|
||||
# - CommandLineCodeResult
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
"""An agent runtime for running remote or cross-language agents.
|
||||
|
||||
Agent messaging uses protobufs from `agent_worker.proto`_ and ``CloudEvent`` from `cloudevent.proto`_.
|
||||
|
||||
Cross-language agents will additionally require all agents use shared protobuf schemas for any message types that are sent between agents.
|
||||
|
||||
.. _agent_worker.proto: https://github.com/microsoft/autogen/blob/main/protos/agent_worker.proto
|
||||
|
||||
.. _cloudevent.proto: https://github.com/microsoft/autogen/blob/main/protos/cloudevent.proto
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Needs to handle agent close() call
|
||||
def __init__(
    self,
    host_address: str,
    tracer_provider: TracerProvider | None = None,
    extra_grpc_config: ChannelArgumentType | None = None,
    payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
) -> None:
    """Set up a worker runtime; no connection is opened until ``start()``.

    Args:
        host_address: Address of the host this worker connects to.
        tracer_provider: Optional tracer provider handed to the trace helper.
        extra_grpc_config: Extra gRPC channel options; ``None`` means no extras.
        payload_serialization_format: Content type used for published payloads;
            must be ``JSON_DATA_CONTENT_TYPE`` or ``PROTOBUF_DATA_CONTENT_TYPE``.

    Raises:
        ValueError: If ``payload_serialization_format`` is not supported.
    """
    self._host_address = host_address
    self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
    self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
    # Agent type name -> factory that builds an agent of that type.
    self._agent_factories: Dict[
        str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
    ] = {}
    self._instantiated_agents: Dict[AgentId, Agent] = {}
    self._known_namespaces: set[str] = set()
    # Task running ``_run_read_loop``; created by ``start()``.
    self._read_task: None | Task[None] = None
    self._running = False
    # Request id -> future awaited by the ``send_message`` caller.
    self._pending_requests: Dict[str, Future[Any]] = {}
    self._pending_requests_lock = asyncio.Lock()
    self._next_request_id = 0
    self._host_connection: HostConnection | None = None
    # Strong references to fire-and-forget tasks until they complete.
    self._background_tasks: Set[Task[Any]] = set()
    self._subscription_manager = SubscriptionManager()
    self._serialization_registry = SerializationRegistry()
    self._extra_grpc_config = extra_grpc_config or []
    self._agent_instance_types: Dict[str, Type[Agent]] = {}

    if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
        raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")

    self._payload_serialization_format = payload_serialization_format
|
||||
|
||||
async def start(self) -> None:
    """Start the runtime in a background task.

    Opens a new connection to the host and makes sure a read loop task is
    running for it.

    Raises:
        ValueError: If the runtime is already running.
    """
    if self._running:
        raise ValueError("Runtime is already running.")
    logger.info(f"Connecting to host: {self._host_address}")
    self._host_connection = await HostConnection.from_host_address(
        self._host_address, extra_grpc_config=self._extra_grpc_config
    )
    logger.info("Connection established")
    # ``stop()`` cancels the read task but leaves the finished task object in
    # ``_read_task``. Treat a finished task like a missing one, otherwise a
    # restarted runtime would never read from its new connection.
    if self._read_task is None or self._read_task.done():
        self._read_task = asyncio.create_task(self._run_read_loop())
    self._running = True
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
exception = task.exception()
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
async def _run_read_loop(self) -> None:
    """Receive host messages and dispatch each one to a background task.

    Runs while ``self._running`` is true. Requests, responses and cloud events
    are each processed in their own task; errors raised inside the loop body
    are logged and the loop carries on.
    """
    logger.info("Starting read loop")
    assert self._host_connection is not None
    # TODO: catch exceptions and reconnect
    while self._running:
        try:
            message = await self._host_connection.recv()
            kind = agent_worker_pb2.Message.WhichOneof(message, "message")
            if kind is None:
                logger.warning("No message")
                continue

            if kind == "request":
                work = self._process_request(message.request)
            elif kind == "response":
                work = self._process_response(message.response)
            elif kind == "cloudEvent":
                work = self._process_event(message.cloudEvent)
            else:
                # Any other member of the oneof is ignored.
                continue

            task = asyncio.create_task(work)
            # Hold a reference until the task finishes, then drop it.
            self._background_tasks.add(task)
            task.add_done_callback(self._raise_on_exception)
            task.add_done_callback(self._background_tasks.discard)
        except Exception as e:
            logger.error("Error in read loop", exc_info=e)
|
||||
|
||||
async def stop(self) -> None:
    """Stop the runtime immediately.

    Waits for the outstanding background tasks, closes the host connection and
    cancels the read loop task.

    Raises:
        RuntimeError: If the runtime is not running.
    """
    if not self._running:
        raise RuntimeError("Runtime is not running.")
    self._running = False

    # Let every in-flight background task finish and report the ones that failed.
    outcomes = await asyncio.gather(*self._background_tasks, return_exceptions=True)
    for failure in (outcome for outcome in outcomes if isinstance(outcome, Exception)):
        logger.error("Error in background task", exc_info=failure)

    # Shut down the connection to the host.
    connection = self._host_connection
    if connection is not None:
        try:
            await connection.close()
        except asyncio.CancelledError:
            pass

    # Cancel the read loop and wait for it to unwind.
    reader = self._read_task
    if reader is not None:
        reader.cancel()
        try:
            await reader
        except asyncio.CancelledError:
            pass
|
||||
|
||||
async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None:
    """Stop the runtime when a signal is received.

    Args:
        signals: Signals that trigger the shutdown.
    """
    running_loop = asyncio.get_running_loop()
    stop_requested = asyncio.Event()

    def _on_signal() -> None:
        logger.info("Received exit signal, shutting down gracefully...")
        stop_requested.set()

    for signum in signals:
        running_loop.add_signal_handler(signum, _on_signal)

    # Block here until one of the signals arrives, then shut the runtime down.
    await stop_requested.wait()
    await self.stop()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
async def _send_message(
|
||||
self,
|
||||
runtime_message: agent_worker_pb2.Message,
|
||||
send_type: Literal["send", "publish"],
|
||||
recipient: AgentId | TopicId,
|
||||
telemetry_metadata: Mapping[str, str],
|
||||
) -> None:
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata):
|
||||
await self._host_connection.send(runtime_message)
|
||||
|
||||
async def send_message(
    self,
    message: Any,
    recipient: AgentId,
    *,
    sender: AgentId | None = None,
    cancellation_token: CancellationToken | None = None,
    message_id: str | None = None,
) -> Any:
    """Send an RPC request to ``recipient`` via the host and await its reply.

    Args:
        message: The message to send; it is serialized as JSON.
        recipient: The agent that should handle the message.
        sender: The sending agent, if any.
        cancellation_token: Currently unused.
        message_id: Currently unused.

    Returns:
        The result set on the pending request when the response arrives.

    Raises:
        ValueError: If the runtime is not running.
        RuntimeError: If there is no host connection.
        Exception: Whatever error the response carries, or the error raised
            while handing the request to the host connection.
    """
    # TODO: use message_id
    if not self._running:
        raise ValueError("Runtime must be running when sending message.")
    if self._host_connection is None:
        raise RuntimeError("Host connection is not set.")
    data_type = self._serialization_registry.type_name(message)
    with self._trace_helper.trace_block(
        "create", recipient, parent=None, extraAttributes={"message_type": data_type}
    ):
        # Serialize first: if this fails no pending request has been recorded
        # yet, so nothing is left behind in ``_pending_requests``.
        serialized_message = self._serialization_registry.serialize(
            message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
        )
        # create a new future for the result
        future = asyncio.get_running_loop().create_future()
        request_id = await self._get_new_request_id()
        self._pending_requests[request_id] = future
        telemetry_metadata = get_telemetry_grpc_metadata()
        runtime_message = agent_worker_pb2.Message(
            request=agent_worker_pb2.RpcRequest(
                request_id=request_id,
                target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
                source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
                metadata=telemetry_metadata,
                payload=agent_worker_pb2.Payload(
                    data_type=data_type,
                    data=serialized_message,
                    data_content_type=JSON_DATA_CONTENT_TYPE,
                ),
            )
        )

        def _fail_request_if_send_failed(send_task: Task[None]) -> None:
            # If the request never left this process no response can arrive,
            # so fail the waiting caller instead of letting it wait forever.
            if send_task.cancelled():
                self._pending_requests.pop(request_id, None)
                future.cancel()
                return
            send_error = send_task.exception()
            if send_error is None:
                return
            self._pending_requests.pop(request_id, None)
            if not future.done():
                future.set_exception(send_error)

        # TODO: Find a way to handle timeouts
        task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata))
        self._background_tasks.add(task)
        task.add_done_callback(_fail_request_if_send_failed)
        task.add_done_callback(self._raise_on_exception)
        task.add_done_callback(self._background_tasks.discard)
        return await future
|
||||
|
||||
async def publish_message(
    self,
    message: Any,
    topic_id: TopicId,
    *,
    sender: AgentId | None = None,
    cancellation_token: CancellationToken | None = None,
    message_id: str | None = None,
) -> None:
    """Publish ``message`` to ``topic_id`` as a CloudEvent sent through the host.

    The payload is serialized with the runtime's configured format and the
    event is handed to a background task for delivery.

    Args:
        message: The message to publish.
        topic_id: The topic to publish to.
        sender: The publishing agent; ``AgentId("unknown", "unknown")`` is
            recorded when omitted.
        cancellation_token: Not used by this method.
        message_id: Event id; a random UUID is generated when omitted.

    Raises:
        ValueError: If the runtime is not running.
        RuntimeError: If there is no host connection.
    """
    if not self._running:
        raise ValueError("Runtime must be running when publishing message.")
    if self._host_connection is None:
        raise RuntimeError("Host connection is not set.")
    if message_id is None:
        message_id = str(uuid.uuid4())

    message_type = self._serialization_registry.type_name(message)
    with self._trace_helper.trace_block(
        "create", topic_id, parent=None, extraAttributes={"message_type": message_type}
    ):
        payload_format = self._payload_serialization_format
        serialized_message = self._serialization_registry.serialize(
            message, type_name=message_type, data_content_type=payload_format
        )

        sender_id = sender or AgentId("unknown", "unknown")
        attribute_strings = {
            _constants.DATA_CONTENT_TYPE_ATTR: payload_format,
            _constants.DATA_SCHEMA_ATTR: message_type,
            _constants.AGENT_SENDER_TYPE_ATTR: sender_id.type,
            _constants.AGENT_SENDER_KEY_ATTR: sender_id.key,
            _constants.MESSAGE_KIND_ATTR: _constants.MESSAGE_KIND_VALUE_PUBLISH,
        }
        attributes = {
            name: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=text)
            for name, text in attribute_strings.items()
        }

        # Fields shared by both payload encodings.
        event_fields: Dict[str, Any] = {
            "id": message_id,
            "spec_version": "1.0",
            "type": topic_id.type,
            "source": topic_id.source,
            "attributes": attributes,
        }

        # If sending JSON we fill binary_data with the serialized message
        # If sending Protobuf we fill proto_data with the serialized message
        # TODO: add an encoding field for serializer
        if payload_format == JSON_DATA_CONTENT_TYPE:
            # TODO: use text, or proto fields appropriately
            event_fields["binary_data"] = serialized_message
        else:
            # We need to unpack the serialized proto back into an Any
            # TODO: find a way to prevent the roundtrip serialization
            any_proto = any_pb2.Any()
            any_proto.ParseFromString(serialized_message)
            event_fields["proto_data"] = any_proto

        runtime_message = agent_worker_pb2.Message(cloudEvent=cloudevent_pb2.CloudEvent(**event_fields))

        telemetry_metadata = get_telemetry_grpc_metadata()
        task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata))
        self._background_tasks.add(task)
        task.add_done_callback(self._raise_on_exception)
        task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
    """Not supported by this runtime; always raises ``NotImplementedError``."""
    raise NotImplementedError("Saving state is not yet implemented.")
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Not supported by this runtime; always raises ``NotImplementedError``."""
    raise NotImplementedError("Loading state is not yet implemented.")
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
    """Not supported by this runtime; always raises ``NotImplementedError``."""
    raise NotImplementedError("Agent metadata is not yet implemented.")
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
    """Not supported by this runtime; always raises ``NotImplementedError``."""
    raise NotImplementedError("Agent save_state is not yet implemented.")
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
    """Not supported by this runtime; always raises ``NotImplementedError``."""
    raise NotImplementedError("Agent load_state is not yet implemented.")
|
||||
|
||||
async def _get_new_request_id(self) -> str:
|
||||
async with self._pending_requests_lock:
|
||||
self._next_request_id += 1
|
||||
return str(self._next_request_id)
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
    """Handle an incoming RPC request and send the response back to the host.

    The payload is deserialized and passed to the target agent's
    ``on_message``. The result is serialized as JSON and sent back; if the
    handler raises, a response carrying the error text is sent instead.

    Args:
        request: The RPC request received from the host.
    """
    assert self._host_connection is not None
    recipient = AgentId(request.target.type, request.target.key)
    sender: AgentId | None = None
    # Log through the module logger (not the root ``logging`` functions) so the
    # records follow the same logger configuration as the rest of this module.
    if request.HasField("source"):
        sender = AgentId(request.source.type, request.source.key)
        logger.info("Processing request from %s to %s", sender, recipient)
    else:
        logger.info("Processing request from unknown source to %s", recipient)

    # Deserialize the message.
    message = self._serialization_registry.deserialize(
        request.payload.data,
        type_name=request.payload.data_type,
        data_content_type=request.payload.data_content_type,
    )

    # Get the receiving agent and prepare the message context.
    rec_agent = await self._get_agent(recipient)
    message_context = MessageContext(
        sender=sender,
        topic_id=None,
        is_rpc=True,
        cancellation_token=CancellationToken(),
        message_id=request.request_id,
    )

    # Call the receiving agent.
    try:
        with MessageHandlerContext.populate_context(rec_agent.id):
            with self._trace_helper.trace_block(
                "process",
                rec_agent.id,
                parent=request.metadata,
                attributes={"request_id": request.request_id},
                extraAttributes={"message_type": request.payload.data_type},
            ):
                result = await rec_agent.on_message(message, ctx=message_context)
    except BaseException as e:
        response_message = agent_worker_pb2.Message(
            response=agent_worker_pb2.RpcResponse(
                request_id=request.request_id,
                error=str(e),
                metadata=get_telemetry_grpc_metadata(),
            ),
        )
        # Send the error response.
        await self._host_connection.send(response_message)
        return

    # Serialize the result.
    result_type = self._serialization_registry.type_name(result)
    serialized_result = self._serialization_registry.serialize(
        result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
    )

    # Create the response message.
    response_message = agent_worker_pb2.Message(
        response=agent_worker_pb2.RpcResponse(
            request_id=request.request_id,
            payload=agent_worker_pb2.Payload(
                data_type=result_type,
                data=serialized_result,
                data_content_type=JSON_DATA_CONTENT_TYPE,
            ),
            metadata=get_telemetry_grpc_metadata(),
        )
    )

    # Send the response.
    await self._host_connection.send(response_message)
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
|
||||
with self._trace_helper.trace_block(
|
||||
"ack",
|
||||
None,
|
||||
parent=response.metadata,
|
||||
attributes={"request_id": response.request_id},
|
||||
extraAttributes={"message_type": response.payload.data_type},
|
||||
):
|
||||
# Deserialize the result.
|
||||
result = self._serialization_registry.deserialize(
|
||||
response.payload.data,
|
||||
type_name=response.payload.data_type,
|
||||
data_content_type=response.payload.data_content_type,
|
||||
)
|
||||
# Get the future and set the result.
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
else:
|
||||
future.set_result(result)
|
||||
|
||||
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
|
||||
event_attributes = event.attributes
|
||||
sender: AgentId | None = None
|
||||
if (
|
||||
_constants.AGENT_SENDER_TYPE_ATTR in event_attributes
|
||||
and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
|
||||
):
|
||||
sender = AgentId(
|
||||
event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
|
||||
event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
|
||||
)
|
||||
topic_id = TopicId(event.type, event.source)
|
||||
# Get the recipients for the topic.
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
|
||||
message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
|
||||
message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string
|
||||
|
||||
if message_content_type == JSON_DATA_CONTENT_TYPE:
|
||||
message = self._serialization_registry.deserialize(
|
||||
event.binary_data, type_name=message_type, data_content_type=message_content_type
|
||||
)
|
||||
elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
|
||||
# TODO: find a way to prevent the roundtrip serialization
|
||||
proto_binary_data = event.proto_data.SerializeToString()
|
||||
message = self._serialization_registry.deserialize(
|
||||
proto_binary_data, type_name=message_type, data_content_type=message_content_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message content type: {message_content_type}")
|
||||
|
||||
# TODO: dont read these values in the runtime
|
||||
topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
|
||||
is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
|
||||
is_marked_rpc_type = (
|
||||
_constants.MESSAGE_KIND_ATTR in event_attributes
|
||||
and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
|
||||
)
|
||||
if is_rpc and not is_marked_rpc_type:
|
||||
warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)
|
||||
|
||||
# Send the message to each recipient.
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent_id in recipients:
|
||||
if agent_id == sender:
|
||||
continue
|
||||
message_context = MessageContext(
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
is_rpc=is_rpc,
|
||||
cancellation_token=CancellationToken(),
|
||||
message_id=event.id,
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
|
||||
def stringify_attributes(
|
||||
attributes: Mapping[str, cloudevent_pb2.CloudEvent.CloudEventAttributeValue],
|
||||
) -> Mapping[str, str]:
|
||||
result: Dict[str, str] = {}
|
||||
for key, value in attributes.items():
|
||||
item = None
|
||||
match value.WhichOneof("attr"):
|
||||
case "ce_boolean":
|
||||
item = str(value.ce_boolean)
|
||||
case "ce_integer":
|
||||
item = str(value.ce_integer)
|
||||
case "ce_string":
|
||||
item = value.ce_string
|
||||
case "ce_bytes":
|
||||
item = str(value.ce_bytes)
|
||||
case "ce_uri":
|
||||
item = value.ce_uri
|
||||
case "ce_uri_ref":
|
||||
item = value.ce_uri_ref
|
||||
case "ce_timestamp":
|
||||
item = str(value.ce_timestamp)
|
||||
case _:
|
||||
raise ValueError("Unknown attribute kind")
|
||||
result[key] = item
|
||||
|
||||
return result
|
||||
|
||||
async def send_message(agent: Agent, message_context: MessageContext) -> Any:
|
||||
with self._trace_helper.trace_block(
|
||||
"process",
|
||||
agent.id,
|
||||
parent=stringify_attributes(event.attributes),
|
||||
extraAttributes={"message_type": message_type},
|
||||
):
|
||||
await agent.on_message(message, ctx=message_context)
|
||||
|
||||
future = send_message(agent, message_context)
|
||||
responses.append(future)
|
||||
# Wait for all responses.
|
||||
try:
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
logger.error("Error handling event", exc_info=e)
|
||||
|
||||
async def _register_agent_type(self, agent_type: str) -> None:
    """Announce ``agent_type`` to the host through the ``RegisterAgent`` RPC.

    Raises:
        RuntimeError: If there is no host connection.
    """
    connection = self._host_connection
    if connection is None:
        raise RuntimeError("Host connection is not set.")
    registration = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type)
    # The reply carries nothing this method uses, so it is awaited and dropped.
    await connection.stub.RegisterAgent(registration, metadata=connection.metadata)
|
||||
|
||||
async def register_factory(
    self,
    type: str | AgentType,
    agent_factory: Callable[[], T | Awaitable[T]],
    *,
    expected_class: type[T] | None = None,
) -> AgentType:
    """Register a factory for an agent type locally and with the host.

    Args:
        type: The agent type, as a name or an ``AgentType``.
        agent_factory: Callable producing the agent, directly or as an awaitable.
        expected_class: If given, the exact class the factory must produce;
            checked each time the factory is invoked.

    Returns:
        The registered ``AgentType``.

    Raises:
        ValueError: If a factory for this type is already registered.
        RuntimeError: If there is no host connection.
    """
    if isinstance(type, str):
        type = AgentType(type)

    if type.type in self._agent_factories:
        raise ValueError(f"Agent with type {type} already exists.")
    if self._host_connection is None:
        raise RuntimeError("Host connection is not set.")

    async def factory_wrapper() -> T:
        maybe_agent_instance = agent_factory()
        if inspect.isawaitable(maybe_agent_instance):
            agent_instance = await maybe_agent_instance
        else:
            agent_instance = maybe_agent_instance

        if expected_class is not None and type_func_alias(agent_instance) != expected_class:
            raise ValueError("Factory registered using the wrong type.")

        return agent_instance

    self._agent_factories[type.type] = factory_wrapper
    # Send the registration request message to the host.
    try:
        await self._register_agent_type(type.type)
    except BaseException:
        # The host never learned about this type: undo the local registration
        # so state stays consistent and the caller can retry without hitting
        # the "already exists" check above.
        if self._agent_factories.get(type.type) is factory_wrapper:
            del self._agent_factories[type.type]
        raise

    return type
|
||||
|
||||
async def register_agent_instance(
    self,
    agent_instance: Agent,
    agent_id: AgentId,
) -> AgentId:
    """Register an already-constructed agent under ``agent_id``.

    The first instance registered for a type stores a placeholder factory for
    that type, registers the type with the host and records the instance's
    class. Later instances of the same type must have the same class, and the
    type must not have been registered through a regular factory.

    Returns:
        The ``agent_id`` that was passed in.

    Raises:
        ValueError: If ``agent_id`` is already instantiated, if the type is
            already backed by a different factory, or if the instance's class
            differs from the one recorded for the type.
    """

    # Placeholder stored in ``_agent_factories`` for instance-registered types.
    # It must never be called; its code object is also what the check below
    # uses to recognise instance-registered types.
    def agent_factory() -> Agent:
        raise RuntimeError(
            "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
        )

    if agent_id in self._instantiated_agents:
        raise ValueError(f"Agent with id {agent_id} already exists.")

    if agent_id.type not in self._agent_factories:
        # First instance of this type: store the placeholder, tell the host,
        # and remember the concrete class for later consistency checks.
        self._agent_factories[agent_id.type] = agent_factory
        await self._register_agent_type(agent_id.type)
        self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
    else:
        # Each call creates a new ``agent_factory`` function object, so the
        # stored factory is compared by code object rather than by identity.
        if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
            raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
        if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
            raise ValueError("Agent instances must be the same object type.")

    await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
    self._instantiated_agents[agent_id] = agent_instance
    return agent_id
|
||||
|
||||
async def _invoke_agent_factory(
    self,
    agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
    agent_id: AgentId,
) -> T:
    """Call ``agent_factory`` inside an ``AgentInstantiationContext`` and return the agent.

    Factories taking no parameters are called as-is. Factories taking two
    parameters receive ``(self, agent_id)`` and trigger a deprecation warning.
    An awaitable result is awaited before being returned.

    Raises:
        ValueError: If the factory takes neither 0 nor 2 parameters.
    """
    with AgentInstantiationContext.populate_context((self, agent_id)):
        arity = len(inspect.signature(agent_factory).parameters)
        if arity == 0:
            agent = cast(Callable[[], T], agent_factory)()
        elif arity == 2:
            warnings.warn(
                "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
                stacklevel=2,
            )
            legacy_factory = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
            agent = legacy_factory(self, agent_id)
        else:
            raise ValueError("Agent factory must take 0 or 2 arguments.")

        # Await async factories while the instantiation context is still populated.
        if inspect.isawaitable(agent):
            agent = cast(T, await agent)

    return agent
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
    """Return the agent for ``agent_id``, creating it from its factory on first use.

    A newly created agent is stored in ``_instantiated_agents`` so later
    lookups return the same object.

    Raises:
        ValueError: If the agent is not instantiated and no factory is stored
            for its type.
    """
    if agent_id not in self._instantiated_agents:
        type_name = agent_id.type
        if type_name not in self._agent_factories:
            raise ValueError(f"Agent with name {type_name} not found.")

        created = await self._invoke_agent_factory(self._agent_factories[type_name], agent_id)
        self._instantiated_agents[agent_id] = created
        return created

    return self._instantiated_agents[agent_id]
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T:  # type: ignore[assignment]
    """Return the local agent object for ``id`` if it is an instance of ``type``.

    Raises:
        LookupError: If no factory is stored for the agent's type.
        TypeError: If the agent exists but is not an instance of ``type``.
    """
    type_name = id.type
    if type_name not in self._agent_factories:
        raise LookupError(f"Agent with name {type_name} not found.")

    # TODO: check if remote
    instance = await self._get_agent(id)

    if isinstance(instance, type):
        return instance
    raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
    """Add ``subscription`` on the host, then in the local subscription manager.

    Raises:
        RuntimeError: If no host connection has been set.
    """
    connection = self._host_connection
    if connection is None:
        raise RuntimeError("Host connection is not set.")

    request = agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription))
    await connection.stub.AddSubscription(request, metadata=connection.metadata)

    # Add to local subscription manager.
    await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
    """Remove the subscription ``id`` on the host, then from the local subscription manager.

    Raises:
        RuntimeError: If no host connection has been set.
    """
    connection = self._host_connection
    if connection is None:
        raise RuntimeError("Host connection is not set.")

    request = agent_worker_pb2.RemoveSubscriptionRequest(id=id)
    await connection.stub.RemoveSubscription(request, metadata=connection.metadata)

    await self._subscription_manager.remove_subscription(id)
|
||||
|
||||
async def get(
    self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
) -> AgentId:
    """Resolve an agent id by delegating to ``get_impl``.

    ``get_impl`` receives the arguments unchanged, with this runtime's
    ``_get_agent`` as the instance getter.
    """
    resolved_id = await get_impl(
        id_or_type=id_or_type,
        key=key,
        lazy=lazy,
        instance_getter=self._get_agent,
    )
    return resolved_id
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
    """Forward one serializer, or a sequence of them, to the serialization registry."""
    registry = self._serialization_registry
    registry.add_serializer(serializer)
|
||||
@@ -0,0 +1,73 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._type_helpers import ChannelArgumentType
|
||||
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
|
||||
|
||||
try:
|
||||
import grpc
|
||||
except ImportError as e:
|
||||
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
|
||||
from .protos import agent_worker_pb2_grpc
|
||||
|
||||
# Module-level logger; it uses the fixed "agentdhal_core" logger name rather than this module's own name.
logger = logging.getLogger("agentdhal_core")
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntimeHost:
    """Run a gRPC server that serves a ``GrpcWorkerAgentRuntimeHostServicer``.

    Args:
        address: Address handed to the server's ``add_insecure_port``.
        extra_grpc_config: Optional channel arguments forwarded to
            ``grpc.aio.server`` as ``options``.
    """

    def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None:
        self._server = grpc.aio.server(options=extra_grpc_config)
        self._servicer = GrpcWorkerAgentRuntimeHostServicer()
        agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
        self._server.add_insecure_port(address)
        self._address = address
        # Background task running ``_serve``; ``None`` while the host is not started.
        self._serve_task: asyncio.Task[None] | None = None

    async def _serve(self) -> None:
        """Start the gRPC server and block until it terminates."""
        await self._server.start()
        logger.info(f"Server started at {self._address}.")
        await self._server.wait_for_termination()

    def start(self) -> None:
        """Start the server in a background task.

        Raises:
            RuntimeError: If the host has already been started.
        """
        if self._serve_task is not None:
            raise RuntimeError("Host runtime is already started.")
        self._serve_task = asyncio.create_task(self._serve())

    async def stop(self, grace: int = 5) -> None:
        """Stop the server.

        Args:
            grace: Value passed to the gRPC server's ``stop`` as ``grace``.

        Raises:
            RuntimeError: If the host has not been started.
        """
        if self._serve_task is None:
            raise RuntimeError("Host runtime is not started.")
        await self._server.stop(grace=grace)
        self._serve_task.cancel()
        try:
            await self._serve_task
        except asyncio.CancelledError:
            pass
        logger.info("Server stopped.")
        self._serve_task = None

    async def stop_when_signal(
        self, grace: int = 5, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)
    ) -> None:
        """Stop the server when a signal is received.

        Signal handlers are installed on the running event loop for the
        duration of this call and removed again before it returns, whether it
        finishes normally, is cancelled, or fails.

        Args:
            grace: Value passed on to :meth:`stop`.
            signals: Signals that trigger the shutdown.

        Raises:
            RuntimeError: If the host has not been started.
        """
        if self._serve_task is None:
            raise RuntimeError("Host runtime is not started.")
        # Set up signal handling for graceful shutdown.
        loop = asyncio.get_running_loop()
        shutdown_event = asyncio.Event()

        def signal_handler() -> None:
            logger.info("Received exit signal, shutting down gracefully...")
            shutdown_event.set()

        # Track what was actually installed so cleanup never touches a signal
        # whose handler registration failed.
        installed = []
        try:
            for sig in signals:
                loop.add_signal_handler(sig, signal_handler)
                installed.append(sig)

            # Wait for the signal to trigger the shutdown event.
            await shutdown_event.wait()

            # Shutdown the server.
            await self.stop(grace=grace)
        finally:
            # Without this the handlers outlive the host and later signals
            # would only set an event nobody waits on.
            for sig in installed:
                loop.remove_signal_handler(sig)
|
||||
@@ -0,0 +1,364 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future, Task
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar
|
||||
|
||||
from agentdhal_core import TopicId
|
||||
from agentdhal_core._agent_id import AgentId
|
||||
from agentdhal_core._runtime_impl_helpers import SubscriptionManager
|
||||
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._utils import subscription_from_proto, subscription_to_proto
|
||||
|
||||
try:
|
||||
import grpc
|
||||
except ImportError as e:
|
||||
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
|
||||
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
|
||||
|
||||
# General-purpose logger used throughout this module.
logger = logging.getLogger("agentdhal_core")

# Separate "events" logger; NOTE(review): not referenced in the visible part of this module.
event_logger = logging.getLogger("agentdhal_core.events")

# Alias for the client id string taken from the "client-id" request metadata.
ClientConnectionId = str
|
||||
|
||||
|
||||
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
    """Convert a sequence of ``(key, value)`` pairs into a dict.

    ``None`` yields an empty dict. When a key occurs more than once the last
    value wins.
    """
    converted: Dict[str, str] = {}
    if metadata is not None:
        for name, content in metadata:
            converted[name] = content
    return converted
|
||||
|
||||
|
||||
async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str:  # type: ignore
    """Return the ``client-id`` from the call's invocation metadata.

    When the metadata has no ``client-id`` entry, the call is aborted with
    ``INVALID_ARGUMENT``.
    """
    # The type hint on context.invocation_metadata() is incorrect.
    metadata = metadata_to_dict(context.invocation_metadata())  # type: ignore
    client_id = metadata.get("client-id")
    if client_id is None:
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.")

    return client_id  # type: ignore
|
||||
|
||||
|
||||
# Type of the messages a channel connection sends out through its send queue.
SendT = TypeVar("SendT")

# Type of the messages a channel connection receives from its request iterator.
ReceiveT = TypeVar("ReceiveT")
|
||||
|
||||
|
||||
class ChannelConnection(ABC, Generic[SendT, ReceiveT]):
    """Two-way channel built from a request iterator and an outgoing queue.

    Incoming messages are read from ``request_iterator`` by a background task
    and passed to :meth:`_handle_message`. Outgoing messages are queued with
    :meth:`send` and produced by iterating the connection asynchronously.
    """

    def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None:
        self._request_iterator = request_iterator
        self._client_id = client_id
        self._send_queue: asyncio.Queue[SendT] = asyncio.Queue()
        # Start consuming incoming messages right away.
        receiver = self._receive_messages(client_id, request_iterator)
        self._receiving_task = asyncio.create_task(receiver)

    async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None:
        """Log every incoming message and pass it to the subclass handler."""
        # Receive messages from the client and process them.
        async for incoming in request_iterator:
            logger.info(f"Received message from client {client_id}: {incoming}")
            await self._handle_message(incoming)

    def __aiter__(self) -> AsyncIterator[SendT]:
        return self

    async def __anext__(self) -> SendT:
        """Return the next queued outgoing message, waiting until one is available."""
        try:
            outgoing = await self._send_queue.get()
        except StopAsyncIteration:
            await self._receiving_task
            raise
        except Exception as error:
            logger.error(f"Failed to get message from send queue: {error}", exc_info=True)
            await self._receiving_task
            raise
        return outgoing

    @abstractmethod
    async def _handle_message(self, message: ReceiveT) -> None:
        """Process one message received from the client."""

    async def send(self, message: SendT) -> None:
        """Queue ``message`` for delivery to the client."""
        await self._send_queue.put(message)
|
||||
|
||||
|
||||
class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]):
    """Channel connection that passes each received message to an async callback."""

    def __init__(
        self,
        request_iterator: AsyncIterator[ReceiveT],
        client_id: str,
        handle_callback: Callable[[ReceiveT], Awaitable[None]],
    ) -> None:
        # Store the callback first: the base initializer starts the receiving task.
        self._handle_callback = handle_callback
        super().__init__(request_iterator=request_iterator, client_id=client_id)

    async def _handle_message(self, message: ReceiveT) -> None:
        callback = self._handle_callback
        await callback(message)
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data_connections: Dict[
|
||||
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
|
||||
] = {}
|
||||
self._control_connections: Dict[
|
||||
ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]
|
||||
] = {}
|
||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
|
||||
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
|
||||
|
||||
async def OpenChannel( # type: ignore
|
||||
self,
|
||||
request_iterator: AsyncIterator[agent_worker_pb2.Message],
|
||||
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
|
||||
) -> AsyncIterator[agent_worker_pb2.Message]:
|
||||
client_id = await get_client_id_or_abort(context)
|
||||
|
||||
async def handle_callback(message: agent_worker_pb2.Message) -> None:
|
||||
await self._receive_message(client_id, message)
|
||||
|
||||
connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message](
|
||||
request_iterator, client_id, handle_callback=handle_callback
|
||||
)
|
||||
self._data_connections[client_id] = connection
|
||||
logger.info(f"Client {client_id} connected.")
|
||||
|
||||
try:
|
||||
async for message in connection:
|
||||
yield message
|
||||
finally:
|
||||
# Clean up the client connection.
|
||||
del self._data_connections[client_id]
|
||||
# Cancel pending requests sent to this client.
|
||||
for future in self._pending_responses.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
# Remove the client id from the agent type to client id mapping.
|
||||
await self._on_client_disconnect(client_id)
|
||||
|
||||
async def OpenControlChannel( # type: ignore
|
||||
self,
|
||||
request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage],
|
||||
context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage],
|
||||
) -> AsyncIterator[agent_worker_pb2.ControlMessage]:
|
||||
client_id = await get_client_id_or_abort(context)
|
||||
|
||||
async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None:
|
||||
await self._receive_control_message(client_id, message)
|
||||
|
||||
connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage](
|
||||
request_iterator, client_id, handle_callback=handle_callback
|
||||
)
|
||||
self._control_connections[client_id] = connection
|
||||
logger.info(f"Client {client_id} connected.")
|
||||
|
||||
try:
|
||||
async for message in connection:
|
||||
yield message
|
||||
finally:
|
||||
# Clean up the client connection.
|
||||
del self._control_connections[client_id]
|
||||
|
||||
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
|
||||
for agent_type in agent_types:
|
||||
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
|
||||
del self._agent_type_to_client_id[agent_type]
|
||||
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
|
||||
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
|
||||
try:
|
||||
await self._subscription_manager.remove_subscription(sub_id)
|
||||
# Catch and ignore if the subscription does not exist.
|
||||
except ValueError:
|
||||
continue
|
||||
logger.info(f"Client {client_id} disconnected successfully")
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
exception = task.exception()
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
oneofcase = message.WhichOneof("message")
|
||||
match oneofcase:
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
task = asyncio.create_task(self._process_response(response, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "cloudEvent":
|
||||
task = asyncio.create_task(self._process_event(message.cloudEvent))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
async def _receive_control_message(
|
||||
self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage
|
||||
) -> None:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
destination = message.destination
|
||||
if destination.startswith("agentid="):
|
||||
agent_id = AgentId.from_str(destination[len("agentid=") :])
|
||||
target_client_id = self._agent_type_to_client_id.get(agent_id.type)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent client id not found for agent type {agent_id.type}.")
|
||||
return
|
||||
elif destination.startswith("clientid="):
|
||||
target_client_id = destination[len("clientid=") :]
|
||||
else:
|
||||
logger.error(f"Invalid destination {destination}")
|
||||
return
|
||||
|
||||
target_send_queue = self._control_connections.get(target_client_id)
|
||||
if target_send_queue is None:
|
||||
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
||||
return
|
||||
await target_send_queue.send(message)
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.type)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
|
||||
return
|
||||
target_send_queue = self._data_connections.get(target_client_id)
|
||||
if target_send_queue is None:
|
||||
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
||||
return
|
||||
await target_send_queue.send(agent_worker_pb2.Message(request=request))
|
||||
|
||||
# Create a future to wait for the response from the target.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
|
||||
|
||||
# Create a task to wait for the response and send it back to the client.
|
||||
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
||||
self._background_tasks.add(send_response_task)
|
||||
send_response_task.add_done_callback(self._raise_on_exception)
|
||||
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _wait_and_send_response(
|
||||
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
|
||||
) -> None:
|
||||
response = await future
|
||||
message = agent_worker_pb2.Message(response=response)
|
||||
send_queue = self._data_connections.get(client_id)
|
||||
if send_queue is None:
|
||||
logger.error(f"Client {client_id} not found, failed to send response message.")
|
||||
return
|
||||
await send_queue.send(message)
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
|
||||
# Setting the result of the future will send the response back to the original sender.
|
||||
future = self._pending_responses[client_id].pop(response.request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
|
||||
topic_id = TopicId(type=event.type, source=event.source)
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
# Get the client ids of the recipients.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
client_ids: Set[ClientConnectionId] = set()
|
||||
for recipient in recipients:
|
||||
client_id = self._agent_type_to_client_id.get(recipient.type)
|
||||
if client_id is not None:
|
||||
client_ids.add(client_id)
|
||||
else:
|
||||
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
|
||||
# Deliver the event to clients.
|
||||
for client_id in client_ids:
|
||||
await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event))
|
||||
|
||||
async def RegisterAgent( # type: ignore
|
||||
self,
|
||||
request: agent_worker_pb2.RegisterAgentTypeRequest,
|
||||
context: grpc.aio.ServicerContext[
|
||||
agent_worker_pb2.RegisterAgentTypeRequest, agent_worker_pb2.RegisterAgentTypeResponse
|
||||
],
|
||||
) -> agent_worker_pb2.RegisterAgentTypeResponse:
|
||||
client_id = await get_client_id_or_abort(context)
|
||||
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
if request.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[request.type]
|
||||
await context.abort(
|
||||
grpc.StatusCode.INVALID_ARGUMENT,
|
||||
f"Agent type {request.type} already registered with client {existing_client_id}.",
|
||||
)
|
||||
else:
|
||||
self._agent_type_to_client_id[request.type] = client_id
|
||||
|
||||
return agent_worker_pb2.RegisterAgentTypeResponse()
|
||||
|
||||
async def AddSubscription( # type: ignore
|
||||
self,
|
||||
request: agent_worker_pb2.AddSubscriptionRequest,
|
||||
context: grpc.aio.ServicerContext[
|
||||
agent_worker_pb2.AddSubscriptionRequest, agent_worker_pb2.AddSubscriptionResponse
|
||||
],
|
||||
) -> agent_worker_pb2.AddSubscriptionResponse:
|
||||
client_id = await get_client_id_or_abort(context)
|
||||
|
||||
subscription = subscription_from_proto(request.subscription)
|
||||
try:
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
|
||||
subscription_ids.add(subscription.id)
|
||||
except ValueError as e:
|
||||
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
||||
return agent_worker_pb2.AddSubscriptionResponse()
|
||||
|
||||
async def RemoveSubscription( # type: ignore
|
||||
self,
|
||||
request: agent_worker_pb2.RemoveSubscriptionRequest,
|
||||
context: grpc.aio.ServicerContext[
|
||||
agent_worker_pb2.RemoveSubscriptionRequest, agent_worker_pb2.RemoveSubscriptionResponse
|
||||
],
|
||||
) -> agent_worker_pb2.RemoveSubscriptionResponse:
|
||||
_client_id = await get_client_id_or_abort(context)
|
||||
await self._subscription_manager.remove_subscription(request.id)
|
||||
return agent_worker_pb2.RemoveSubscriptionResponse()
|
||||
|
||||
async def GetSubscriptions( # type: ignore
|
||||
self,
|
||||
request: agent_worker_pb2.GetSubscriptionsRequest,
|
||||
context: grpc.aio.ServicerContext[
|
||||
agent_worker_pb2.GetSubscriptionsRequest, agent_worker_pb2.GetSubscriptionsResponse
|
||||
],
|
||||
) -> agent_worker_pb2.GetSubscriptionsResponse:
|
||||
_client_id = await get_client_id_or_abort(context)
|
||||
subscriptions = self._subscription_manager.subscriptions
|
||||
return agent_worker_pb2.GetSubscriptionsResponse(
|
||||
subscriptions=[subscription_to_proto(sub) for sub in subscriptions]
|
||||
)
|
||||
|
||||
# async def GetState( # type: ignore
|
||||
# self,
|
||||
# request: agent_worker_pb2.AgentId,
|
||||
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
|
||||
# ) -> agent_worker_pb2.GetStateResponse:
|
||||
# _client_id = await get_client_id_or_abort(context)
|
||||
# raise NotImplementedError("Method not implemented!")
|
||||
|
||||
# async def SaveState( # type: ignore
|
||||
# self,
|
||||
# request: agent_worker_pb2.AgentState,
|
||||
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
|
||||
# ) -> agent_worker_pb2.SaveStateResponse:
|
||||
# _client_id = await get_client_id_or_abort(context)
|
||||
# raise NotImplementedError("Method not implemented!")
|
||||
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
The :mod:`agentdhal_extensions.runtimes.grpc.protos` module provides Google Protobuf classes for agent-worker communication
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: agent_worker.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'agent_worker.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from . import cloudevent_pb2 as cloudevent__pb2
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 
\x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message\"4\n\x10SaveStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\"@\n\x11SaveStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"C\n\x10LoadStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\r\n\x05state\x18\x02 \x01(\t\"1\n\x11LoadStateResponse\x12\x12\n\x05\x65rror\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x87\x01\n\x0e\x43ontrolMessage\x12\x0e\n\x06rpc_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65stination\x18\x02 \x01(\t\x12\x17\n\nrespond_to\x18\x03 \x01(\tH\x00\x88\x01\x01\x12(\n\nrpcMessage\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyB\r\n\x0b_respond_to2\xe7\x03\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12H\n\x12OpenControlChannel\x12\x16.agents.ControlMessage\x1a\x16.agents.ControlMessage(\x01\x30\x01\x12T\n\rRegisterAgent\x12 .agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3')
|
||||
|
||||
# Populate this module's namespace with the descriptors and message classes
# described by DESCRIPTOR.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_worker_pb2', _globals)
# Everything below runs only when the C descriptor implementation is NOT in
# use: options are reset/stored in serialized form and each descriptor gets
# its _serialized_start/_serialized_end values.
# NOTE(review): these values look like byte offsets into the serialized file
# descriptor above — generator output, do not hand-edit.
if not _descriptor._USE_C_DESCRIPTORS:
    _globals['DESCRIPTOR']._loaded_options = None
    _globals['DESCRIPTOR']._serialized_options = b'\252\002\032Microsoft.AutoGen.Protobuf'
    _globals['_RPCREQUEST_METADATAENTRY']._loaded_options = None
    _globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001'
    _globals['_RPCRESPONSE_METADATAENTRY']._loaded_options = None
    _globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001'
    _globals['_AGENTID']._serialized_start=75
    _globals['_AGENTID']._serialized_end=111
    _globals['_PAYLOAD']._serialized_start=113
    _globals['_PAYLOAD']._serialized_end=182
    _globals['_RPCREQUEST']._serialized_start=185
    _globals['_RPCREQUEST']._serialized_end=450
    _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=392
    _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=439
    _globals['_RPCRESPONSE']._serialized_start=453
    _globals['_RPCRESPONSE']._serialized_end=637
    # NOTE(review): same 392/439 pair as the RpcRequest entry even though
    # RpcResponse spans 453-637 — presumably because both MetadataEntry
    # messages serialize identically; confirm before "correcting" it.
    _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=392
    _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=439
    _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=639
    _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=679
    _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=681
    _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=708
    _globals['_TYPESUBSCRIPTION']._serialized_start=710
    _globals['_TYPESUBSCRIPTION']._serialized_end=768
    _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=770
    _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=841
    _globals['_SUBSCRIPTION']._serialized_start=844
    _globals['_SUBSCRIPTION']._serialized_end=1006
    _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1008
    _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1076
    _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1078
    _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1103
    _globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_start=1105
    _globals['_REMOVESUBSCRIPTIONREQUEST']._serialized_end=1144
    _globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_start=1146
    _globals['_REMOVESUBSCRIPTIONRESPONSE']._serialized_end=1174
    _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_start=1176
    _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_end=1201
    _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_start=1203
    _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_end=1274
    _globals['_MESSAGE']._serialized_start=1277
    _globals['_MESSAGE']._serialized_end=1430
    _globals['_SAVESTATEREQUEST']._serialized_start=1432
    _globals['_SAVESTATEREQUEST']._serialized_end=1484
    _globals['_SAVESTATERESPONSE']._serialized_start=1486
    _globals['_SAVESTATERESPONSE']._serialized_end=1550
    _globals['_LOADSTATEREQUEST']._serialized_start=1552
    _globals['_LOADSTATEREQUEST']._serialized_end=1619
    _globals['_LOADSTATERESPONSE']._serialized_start=1621
    _globals['_LOADSTATERESPONSE']._serialized_end=1670
    _globals['_CONTROLMESSAGE']._serialized_start=1673
    _globals['_CONTROLMESSAGE']._serialized_end=1808
    _globals['_AGENTRPC']._serialized_start=1811
    _globals['_AGENTRPC']._serialized_end=2298
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
from . import cloudevent_pb2
|
||||
import collections.abc
|
||||
import google.protobuf.any_pb2
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
class AgentId(google.protobuf.message.Message):
    """Typing stub for the ``agents.AgentId`` message (string ``type`` and ``key``)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Message fields.
    type: builtins.str
    key: builtins.str

    # Field-number constants.
    TYPE_FIELD_NUMBER: builtins.int
    KEY_FIELD_NUMBER: builtins.int

    def __init__(self, *, type: builtins.str = ..., key: builtins.str = ...) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal["key", b"key", "type", b"type"],
    ) -> None: ...


global___AgentId = AgentId
|
||||
|
||||
@typing.final
class Payload(google.protobuf.message.Message):
    """Typing stub for the ``agents.Payload`` message: raw ``data`` bytes plus two type strings."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Message fields.
    data_type: builtins.str
    data_content_type: builtins.str
    data: builtins.bytes

    # Field-number constants.
    DATA_TYPE_FIELD_NUMBER: builtins.int
    DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int
    DATA_FIELD_NUMBER: builtins.int

    def __init__(
        self,
        *,
        data_type: builtins.str = ...,
        data_content_type: builtins.str = ...,
        data: builtins.bytes = ...,
    ) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "data", b"data",
            "data_content_type", b"data_content_type",
            "data_type", b"data_type",
        ],
    ) -> None: ...


global___Payload = Payload
|
||||
|
||||
@typing.final
class RpcRequest(google.protobuf.message.Message):
    """Typing stub for the ``agents.RpcRequest`` message."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    @typing.final
    class MetadataEntry(google.protobuf.message.Message):
        """Stub for a single string-to-string entry of the ``metadata`` map."""

        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        key: builtins.str
        value: builtins.str

        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int

        def __init__(self, *, key: builtins.str = ..., value: builtins.str = ...) -> None: ...

        def ClearField(
            self,
            field_name: typing.Literal["key", b"key", "value", b"value"],
        ) -> None: ...

    # Scalar fields.
    request_id: builtins.str
    method: builtins.str

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: builtins.int
    SOURCE_FIELD_NUMBER: builtins.int
    TARGET_FIELD_NUMBER: builtins.int
    METHOD_FIELD_NUMBER: builtins.int
    PAYLOAD_FIELD_NUMBER: builtins.int
    METADATA_FIELD_NUMBER: builtins.int

    # Message-typed and map fields are exposed as read-only properties.
    @property
    def source(self) -> global___AgentId: ...

    @property
    def target(self) -> global___AgentId: ...

    @property
    def payload(self) -> global___Payload: ...

    @property
    def metadata(
        self,
    ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...

    def __init__(
        self,
        *,
        request_id: builtins.str = ...,
        source: global___AgentId | None = ...,
        target: global___AgentId | None = ...,
        method: builtins.str = ...,
        payload: global___Payload | None = ...,
        metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal[
            "_source", b"_source",
            "payload", b"payload",
            "source", b"source",
            "target", b"target",
        ],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "_source", b"_source",
            "metadata", b"metadata",
            "method", b"method",
            "payload", b"payload",
            "request_id", b"request_id",
            "source", b"source",
            "target", b"target",
        ],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["_source", b"_source"],
    ) -> typing.Literal["source"] | None: ...


global___RpcRequest = RpcRequest
|
||||
|
||||
@typing.final
class RpcResponse(google.protobuf.message.Message):
    """Typing stub for the ``agents.RpcResponse`` message."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    @typing.final
    class MetadataEntry(google.protobuf.message.Message):
        """Stub for a single string-to-string entry of the ``metadata`` map."""

        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        key: builtins.str
        value: builtins.str

        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int

        def __init__(self, *, key: builtins.str = ..., value: builtins.str = ...) -> None: ...

        def ClearField(
            self,
            field_name: typing.Literal["key", b"key", "value", b"value"],
        ) -> None: ...

    # Scalar fields.
    request_id: builtins.str
    error: builtins.str

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: builtins.int
    PAYLOAD_FIELD_NUMBER: builtins.int
    ERROR_FIELD_NUMBER: builtins.int
    METADATA_FIELD_NUMBER: builtins.int

    @property
    def payload(self) -> global___Payload: ...

    @property
    def metadata(
        self,
    ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...

    def __init__(
        self,
        *,
        request_id: builtins.str = ...,
        payload: global___Payload | None = ...,
        error: builtins.str = ...,
        metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["payload", b"payload"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "error", b"error",
            "metadata", b"metadata",
            "payload", b"payload",
            "request_id", b"request_id",
        ],
    ) -> None: ...


global___RpcResponse = RpcResponse
|
||||
|
||||
@typing.final
class RegisterAgentTypeRequest(google.protobuf.message.Message):
    """Typing stub for the ``agents.RegisterAgentTypeRequest`` message (single string ``type``)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    type: builtins.str

    TYPE_FIELD_NUMBER: builtins.int

    def __init__(self, *, type: builtins.str = ...) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal["type", b"type"],
    ) -> None: ...


global___RegisterAgentTypeRequest = RegisterAgentTypeRequest
|
||||
|
||||
@typing.final
class RegisterAgentTypeResponse(google.protobuf.message.Message):
    """Typing stub for the ``agents.RegisterAgentTypeResponse`` message, which declares no fields."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    def __init__(self) -> None: ...


global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
|
||||
|
||||
@typing.final
class TypeSubscription(google.protobuf.message.Message):
    """Typing stub for ``agents.TypeSubscription``: a ``topic_type`` paired with an ``agent_type``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Message fields.
    topic_type: builtins.str
    agent_type: builtins.str

    # Field-number constants.
    TOPIC_TYPE_FIELD_NUMBER: builtins.int
    AGENT_TYPE_FIELD_NUMBER: builtins.int

    def __init__(
        self,
        *,
        topic_type: builtins.str = ...,
        agent_type: builtins.str = ...,
    ) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "agent_type", b"agent_type",
            "topic_type", b"topic_type",
        ],
    ) -> None: ...


global___TypeSubscription = TypeSubscription
|
||||
|
||||
@typing.final
class TypePrefixSubscription(google.protobuf.message.Message):
    """Typing stub for ``agents.TypePrefixSubscription``: a ``topic_type_prefix`` paired with an ``agent_type``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Message fields.
    topic_type_prefix: builtins.str
    agent_type: builtins.str

    # Field-number constants.
    TOPIC_TYPE_PREFIX_FIELD_NUMBER: builtins.int
    AGENT_TYPE_FIELD_NUMBER: builtins.int

    def __init__(
        self,
        *,
        topic_type_prefix: builtins.str = ...,
        agent_type: builtins.str = ...,
    ) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "agent_type", b"agent_type",
            "topic_type_prefix", b"topic_type_prefix",
        ],
    ) -> None: ...


global___TypePrefixSubscription = TypePrefixSubscription
|
||||
|
||||
@typing.final
class Subscription(google.protobuf.message.Message):
    """Typing stub for ``agents.Subscription``: an ``id`` plus the ``subscription`` oneof."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    id: builtins.str

    # Field-number constants.
    ID_FIELD_NUMBER: builtins.int
    TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int
    TYPEPREFIXSUBSCRIPTION_FIELD_NUMBER: builtins.int

    # Members of the "subscription" oneof.
    @property
    def typeSubscription(self) -> global___TypeSubscription: ...

    @property
    def typePrefixSubscription(self) -> global___TypePrefixSubscription: ...

    def __init__(
        self,
        *,
        id: builtins.str = ...,
        typeSubscription: global___TypeSubscription | None = ...,
        typePrefixSubscription: global___TypePrefixSubscription | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal[
            "subscription", b"subscription",
            "typePrefixSubscription", b"typePrefixSubscription",
            "typeSubscription", b"typeSubscription",
        ],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "id", b"id",
            "subscription", b"subscription",
            "typePrefixSubscription", b"typePrefixSubscription",
            "typeSubscription", b"typeSubscription",
        ],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["subscription", b"subscription"],
    ) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...


global___Subscription = Subscription
|
||||
|
||||
@typing.final
class AddSubscriptionRequest(google.protobuf.message.Message):
    """Typing stub for ``agents.AddSubscriptionRequest``, which carries one ``subscription``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    SUBSCRIPTION_FIELD_NUMBER: builtins.int

    @property
    def subscription(self) -> global___Subscription: ...

    def __init__(self, *, subscription: global___Subscription | None = ...) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["subscription", b"subscription"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal["subscription", b"subscription"],
    ) -> None: ...


global___AddSubscriptionRequest = AddSubscriptionRequest
|
||||
|
||||
@typing.final
class AddSubscriptionResponse(google.protobuf.message.Message):
    """Typing stub for the ``agents.AddSubscriptionResponse`` message, which declares no fields."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    def __init__(self) -> None: ...


global___AddSubscriptionResponse = AddSubscriptionResponse
|
||||
|
||||
@typing.final
class RemoveSubscriptionRequest(google.protobuf.message.Message):
    """Typing stub for ``agents.RemoveSubscriptionRequest`` (single string ``id``)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    id: builtins.str

    ID_FIELD_NUMBER: builtins.int

    def __init__(self, *, id: builtins.str = ...) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal["id", b"id"],
    ) -> None: ...


global___RemoveSubscriptionRequest = RemoveSubscriptionRequest
|
||||
|
||||
@typing.final
class RemoveSubscriptionResponse(google.protobuf.message.Message):
    """Typing stub for the ``agents.RemoveSubscriptionResponse`` message, which declares no fields."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    def __init__(self) -> None: ...


global___RemoveSubscriptionResponse = RemoveSubscriptionResponse
|
||||
|
||||
@typing.final
class GetSubscriptionsRequest(google.protobuf.message.Message):
    """Typing stub for the ``agents.GetSubscriptionsRequest`` message, which declares no fields."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    def __init__(self) -> None: ...


global___GetSubscriptionsRequest = GetSubscriptionsRequest
|
||||
|
||||
@typing.final
class GetSubscriptionsResponse(google.protobuf.message.Message):
    """Typing stub for ``agents.GetSubscriptionsResponse``: a repeated ``subscriptions`` field."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    SUBSCRIPTIONS_FIELD_NUMBER: builtins.int

    @property
    def subscriptions(
        self,
    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Subscription]: ...

    def __init__(
        self,
        *,
        subscriptions: collections.abc.Iterable[global___Subscription] | None = ...,
    ) -> None: ...

    def ClearField(
        self,
        field_name: typing.Literal["subscriptions", b"subscriptions"],
    ) -> None: ...


global___GetSubscriptionsResponse = GetSubscriptionsResponse
|
||||
|
||||
@typing.final
class Message(google.protobuf.message.Message):
    """Typing stub for ``agents.Message``: a ``message`` oneof of request, response or cloudEvent."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Field-number constants.
    REQUEST_FIELD_NUMBER: builtins.int
    RESPONSE_FIELD_NUMBER: builtins.int
    CLOUDEVENT_FIELD_NUMBER: builtins.int

    # Members of the "message" oneof.
    @property
    def request(self) -> global___RpcRequest: ...

    @property
    def response(self) -> global___RpcResponse: ...

    @property
    def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ...

    def __init__(
        self,
        *,
        request: global___RpcRequest | None = ...,
        response: global___RpcResponse | None = ...,
        cloudEvent: cloudevent_pb2.CloudEvent | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal[
            "cloudEvent", b"cloudEvent",
            "message", b"message",
            "request", b"request",
            "response", b"response",
        ],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "cloudEvent", b"cloudEvent",
            "message", b"message",
            "request", b"request",
            "response", b"response",
        ],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["message", b"message"],
    ) -> typing.Literal["request", "response", "cloudEvent"] | None: ...


global___Message = Message
|
||||
|
||||
@typing.final
class SaveStateRequest(google.protobuf.message.Message):
    """Typing stub for ``agents.SaveStateRequest``, which carries one ``agentId``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    AGENTID_FIELD_NUMBER: builtins.int

    @property
    def agentId(self) -> global___AgentId: ...

    def __init__(self, *, agentId: global___AgentId | None = ...) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["agentId", b"agentId"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal["agentId", b"agentId"],
    ) -> None: ...


global___SaveStateRequest = SaveStateRequest
|
||||
|
||||
@typing.final
class SaveStateResponse(google.protobuf.message.Message):
    """Typing stub for ``agents.SaveStateResponse``: a ``state`` string and an optional ``error``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Message fields.
    state: builtins.str
    error: builtins.str

    # Field-number constants.
    STATE_FIELD_NUMBER: builtins.int
    ERROR_FIELD_NUMBER: builtins.int

    def __init__(
        self,
        *,
        state: builtins.str = ...,
        error: builtins.str | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["_error", b"_error", "error", b"error"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "_error", b"_error",
            "error", b"error",
            "state", b"state",
        ],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["_error", b"_error"],
    ) -> typing.Literal["error"] | None: ...


global___SaveStateResponse = SaveStateResponse
|
||||
|
||||
@typing.final
class LoadStateRequest(google.protobuf.message.Message):
    """Typing stub for ``agents.LoadStateRequest``: an ``agentId`` and a ``state`` string."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    state: builtins.str

    # Field-number constants.
    AGENTID_FIELD_NUMBER: builtins.int
    STATE_FIELD_NUMBER: builtins.int

    @property
    def agentId(self) -> global___AgentId: ...

    def __init__(
        self,
        *,
        agentId: global___AgentId | None = ...,
        state: builtins.str = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["agentId", b"agentId"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal["agentId", b"agentId", "state", b"state"],
    ) -> None: ...


global___LoadStateRequest = LoadStateRequest
|
||||
|
||||
@typing.final
class LoadStateResponse(google.protobuf.message.Message):
    """Typing stub for ``agents.LoadStateResponse``, which carries only an optional ``error``."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    error: builtins.str

    ERROR_FIELD_NUMBER: builtins.int

    def __init__(self, *, error: builtins.str | None = ...) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal["_error", b"_error", "error", b"error"],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal["_error", b"_error", "error", b"error"],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["_error", b"_error"],
    ) -> typing.Literal["error"] | None: ...


global___LoadStateResponse = LoadStateResponse
|
||||
|
||||
@typing.final
class ControlMessage(google.protobuf.message.Message):
    """Typing stub for the ``agents.ControlMessage`` message."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    # Field-number constants.
    RPC_ID_FIELD_NUMBER: builtins.int
    DESTINATION_FIELD_NUMBER: builtins.int
    RESPOND_TO_FIELD_NUMBER: builtins.int
    RPCMESSAGE_FIELD_NUMBER: builtins.int

    rpc_id: builtins.str
    """A response message should have the same id as the request message"""
    destination: builtins.str
    """This is either:
    agentid=AGENT_ID
    clientid=CLIENT_ID
    """
    respond_to: builtins.str
    """This is either:
    agentid=AGENT_ID
    clientid=CLIENT_ID
    Empty string means the message is a response
    """

    @property
    def rpcMessage(self) -> google.protobuf.any_pb2.Any:
        """One of:
        SaveStateRequest saveStateRequest = 2;
        SaveStateResponse saveStateResponse = 3;
        LoadStateRequest loadStateRequest = 4;
        LoadStateResponse loadStateResponse = 5;
        """

    def __init__(
        self,
        *,
        rpc_id: builtins.str = ...,
        destination: builtins.str = ...,
        respond_to: builtins.str | None = ...,
        rpcMessage: google.protobuf.any_pb2.Any | None = ...,
    ) -> None: ...

    def HasField(
        self,
        field_name: typing.Literal[
            "_respond_to", b"_respond_to",
            "respond_to", b"respond_to",
            "rpcMessage", b"rpcMessage",
        ],
    ) -> builtins.bool: ...

    def ClearField(
        self,
        field_name: typing.Literal[
            "_respond_to", b"_respond_to",
            "destination", b"destination",
            "respond_to", b"respond_to",
            "rpcMessage", b"rpcMessage",
            "rpc_id", b"rpc_id",
        ],
    ) -> None: ...

    def WhichOneof(
        self,
        oneof_group: typing.Literal["_respond_to", b"_respond_to"],
    ) -> typing.Literal["respond_to"] | None: ...


global___ControlMessage = ControlMessage
|
||||
@@ -0,0 +1,312 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import agent_worker_pb2 as agent__worker__pb2
|
||||
|
||||
# grpcio version this module was generated with, and the version installed now.
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__

# The installed grpcio is unsupported when it is older than the generating
# version, or when it does not ship the comparison helper at all.
try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

# Refuse to load the module against an unsupported grpcio.
if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION}, but the generated code'
        f' in agent_worker_pb2_grpc.py depends on grpcio>={GRPC_GENERATED_VERSION}.'
        f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
||||
|
||||
|
||||
class AgentRpcStub(object):
    """Client stub exposing each ``agents.AgentRpc`` method as an attribute bound to a channel."""

    def __init__(self, channel):
        """Bind every AgentRpc method to ``channel``.

        Args:
            channel: A grpc.Channel.
        """
        pb = agent__worker__pb2

        # Stream-stream RPCs.
        self.OpenChannel = channel.stream_stream(
            '/agents.AgentRpc/OpenChannel',
            request_serializer=pb.Message.SerializeToString,
            response_deserializer=pb.Message.FromString,
            _registered_method=True,
        )
        self.OpenControlChannel = channel.stream_stream(
            '/agents.AgentRpc/OpenControlChannel',
            request_serializer=pb.ControlMessage.SerializeToString,
            response_deserializer=pb.ControlMessage.FromString,
            _registered_method=True,
        )

        # Unary-unary RPCs.
        self.RegisterAgent = channel.unary_unary(
            '/agents.AgentRpc/RegisterAgent',
            request_serializer=pb.RegisterAgentTypeRequest.SerializeToString,
            response_deserializer=pb.RegisterAgentTypeResponse.FromString,
            _registered_method=True,
        )
        self.AddSubscription = channel.unary_unary(
            '/agents.AgentRpc/AddSubscription',
            request_serializer=pb.AddSubscriptionRequest.SerializeToString,
            response_deserializer=pb.AddSubscriptionResponse.FromString,
            _registered_method=True,
        )
        self.RemoveSubscription = channel.unary_unary(
            '/agents.AgentRpc/RemoveSubscription',
            request_serializer=pb.RemoveSubscriptionRequest.SerializeToString,
            response_deserializer=pb.RemoveSubscriptionResponse.FromString,
            _registered_method=True,
        )
        self.GetSubscriptions = channel.unary_unary(
            '/agents.AgentRpc/GetSubscriptions',
            request_serializer=pb.GetSubscriptionsRequest.SerializeToString,
            response_deserializer=pb.GetSubscriptionsResponse.FromString,
            _registered_method=True,
        )
|
||||
|
||||
|
||||
class AgentRpcServicer(object):
    """Base servicer for ``agents.AgentRpc``; each handler reports UNIMPLEMENTED unless overridden."""

    def OpenChannel(self, request_iterator, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def OpenControlChannel(self, request_iterator, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def RegisterAgent(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def AddSubscription(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def RemoveSubscription(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def GetSubscriptions(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise NotImplementedError."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)
|
||||
|
||||
|
||||
def add_AgentRpcServicer_to_server(servicer, server):
    """Register the ``agents.AgentRpc`` handlers of ``servicer`` on ``server``.

    Args:
        servicer: Object providing the six AgentRpc handler methods.
        server: Server object offering ``add_generic_rpc_handlers`` and
            ``add_registered_method_handlers``.
    """
    pb = agent__worker__pb2
    service_name = 'agents.AgentRpc'

    # One handler per RPC, keyed by method name; (de)serializers come from
    # the matching request/response message classes.
    rpc_method_handlers = {
        'OpenChannel': grpc.stream_stream_rpc_method_handler(
            servicer.OpenChannel,
            request_deserializer=pb.Message.FromString,
            response_serializer=pb.Message.SerializeToString,
        ),
        'OpenControlChannel': grpc.stream_stream_rpc_method_handler(
            servicer.OpenControlChannel,
            request_deserializer=pb.ControlMessage.FromString,
            response_serializer=pb.ControlMessage.SerializeToString,
        ),
        'RegisterAgent': grpc.unary_unary_rpc_method_handler(
            servicer.RegisterAgent,
            request_deserializer=pb.RegisterAgentTypeRequest.FromString,
            response_serializer=pb.RegisterAgentTypeResponse.SerializeToString,
        ),
        'AddSubscription': grpc.unary_unary_rpc_method_handler(
            servicer.AddSubscription,
            request_deserializer=pb.AddSubscriptionRequest.FromString,
            response_serializer=pb.AddSubscriptionResponse.SerializeToString,
        ),
        'RemoveSubscription': grpc.unary_unary_rpc_method_handler(
            servicer.RemoveSubscription,
            request_deserializer=pb.RemoveSubscriptionRequest.FromString,
            response_serializer=pb.RemoveSubscriptionResponse.SerializeToString,
        ),
        'GetSubscriptions': grpc.unary_unary_rpc_method_handler(
            servicer.GetSubscriptions,
            request_deserializer=pb.GetSubscriptionsRequest.FromString,
            response_serializer=pb.GetSubscriptionsResponse.SerializeToString,
        ),
    }

    # Register the handlers both generically and as registered methods.
    generic_handler = grpc.method_handlers_generic_handler(service_name, rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers(service_name, rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AgentRpc(object):
    """Static call helpers for the ``agents.AgentRpc`` service.

    Every method forwards its arguments to ``grpc.experimental.stream_stream``
    or ``grpc.experimental.unary_unary`` with the method path and the message
    (de)serializers filled in.  The .proto file carries no documentation
    comment for this service.
    """

    @staticmethod
    def OpenChannel(
        request_iterator,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the stream-stream ``OpenChannel`` RPC (``Message`` both ways)."""
        # Positional order expected by grpc.experimental: note that
        # ``insecure`` precedes ``call_credentials``.
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            '/agents.AgentRpc/OpenChannel',
            agent__worker__pb2.Message.SerializeToString,
            agent__worker__pb2.Message.FromString,
            *call_options,
            _registered_method=True,
        )

    @staticmethod
    def OpenControlChannel(
        request_iterator,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the stream-stream ``OpenControlChannel`` RPC (``ControlMessage`` both ways)."""
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            '/agents.AgentRpc/OpenControlChannel',
            agent__worker__pb2.ControlMessage.SerializeToString,
            agent__worker__pb2.ControlMessage.FromString,
            *call_options,
            _registered_method=True,
        )

    @staticmethod
    def RegisterAgent(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary-unary ``RegisterAgent`` RPC."""
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.unary_unary(
            request,
            target,
            '/agents.AgentRpc/RegisterAgent',
            agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString,
            agent__worker__pb2.RegisterAgentTypeResponse.FromString,
            *call_options,
            _registered_method=True,
        )

    @staticmethod
    def AddSubscription(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary-unary ``AddSubscription`` RPC."""
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.unary_unary(
            request,
            target,
            '/agents.AgentRpc/AddSubscription',
            agent__worker__pb2.AddSubscriptionRequest.SerializeToString,
            agent__worker__pb2.AddSubscriptionResponse.FromString,
            *call_options,
            _registered_method=True,
        )

    @staticmethod
    def RemoveSubscription(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary-unary ``RemoveSubscription`` RPC."""
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.unary_unary(
            request,
            target,
            '/agents.AgentRpc/RemoveSubscription',
            agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString,
            agent__worker__pb2.RemoveSubscriptionResponse.FromString,
            *call_options,
            _registered_method=True,
        )

    @staticmethod
    def GetSubscriptions(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary-unary ``GetSubscriptions`` RPC."""
        call_options = (
            options, channel_credentials, insecure, call_credentials,
            compression, wait_for_ready, timeout, metadata,
        )
        return grpc.experimental.unary_unary(
            request,
            target,
            '/agents.AgentRpc/GetSubscriptions',
            agent__worker__pb2.GetSubscriptionsRequest.SerializeToString,
            agent__worker__pb2.GetSubscriptionsResponse.FromString,
            *call_options,
            _registered_method=True,
        )
|
||||
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import abc
|
||||
from . import agent_worker_pb2
|
||||
import collections.abc
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import typing
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
# Typing-only helper: an abstract type that is both an AsyncIterator[_T] and
# an Iterator[_T]; used below to annotate streaming request iterators.
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
|
||||
|
||||
# Typing-only helper combining grpc.ServicerContext and
# grpc.aio.ServicerContext so servicer methods can be annotated once.
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext):  # type: ignore[misc, type-arg]
    ...
|
||||
|
||||
class AgentRpcStub:
    """Type stub for the AgentRpc client stub.

    Declares one multi-callable attribute per RPC, typed with the ``grpc``
    (non-``aio``) multi-callable classes and parameterized by the request
    and response message types.
    """

    def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...

    # Stream-stream RPCs.
    OpenChannel: grpc.StreamStreamMultiCallable[
        agent_worker_pb2.Message,
        agent_worker_pb2.Message,
    ]

    OpenControlChannel: grpc.StreamStreamMultiCallable[
        agent_worker_pb2.ControlMessage,
        agent_worker_pb2.ControlMessage,
    ]

    # Unary-unary RPCs.
    RegisterAgent: grpc.UnaryUnaryMultiCallable[
        agent_worker_pb2.RegisterAgentTypeRequest,
        agent_worker_pb2.RegisterAgentTypeResponse,
    ]

    AddSubscription: grpc.UnaryUnaryMultiCallable[
        agent_worker_pb2.AddSubscriptionRequest,
        agent_worker_pb2.AddSubscriptionResponse,
    ]

    RemoveSubscription: grpc.UnaryUnaryMultiCallable[
        agent_worker_pb2.RemoveSubscriptionRequest,
        agent_worker_pb2.RemoveSubscriptionResponse,
    ]

    GetSubscriptions: grpc.UnaryUnaryMultiCallable[
        agent_worker_pb2.GetSubscriptionsRequest,
        agent_worker_pb2.GetSubscriptionsResponse,
    ]
|
||||
|
||||
class AgentRpcAsyncStub:
    """Type stub describing the AgentRpc client with ``grpc.aio`` multi-callables.

    Declares the same six RPC attributes as ``AgentRpcStub`` but typed with
    the ``grpc.aio`` multi-callable classes; no ``__init__`` is declared here.
    """

    # Stream-stream RPCs.
    OpenChannel: grpc.aio.StreamStreamMultiCallable[
        agent_worker_pb2.Message,
        agent_worker_pb2.Message,
    ]

    OpenControlChannel: grpc.aio.StreamStreamMultiCallable[
        agent_worker_pb2.ControlMessage,
        agent_worker_pb2.ControlMessage,
    ]

    # Unary-unary RPCs.
    RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[
        agent_worker_pb2.RegisterAgentTypeRequest,
        agent_worker_pb2.RegisterAgentTypeResponse,
    ]

    AddSubscription: grpc.aio.UnaryUnaryMultiCallable[
        agent_worker_pb2.AddSubscriptionRequest,
        agent_worker_pb2.AddSubscriptionResponse,
    ]

    RemoveSubscription: grpc.aio.UnaryUnaryMultiCallable[
        agent_worker_pb2.RemoveSubscriptionRequest,
        agent_worker_pb2.RemoveSubscriptionResponse,
    ]

    GetSubscriptions: grpc.aio.UnaryUnaryMultiCallable[
        agent_worker_pb2.GetSubscriptionsRequest,
        agent_worker_pb2.GetSubscriptionsResponse,
    ]
|
||||
|
||||
class AgentRpcServicer(metaclass=abc.ABCMeta):
    """Type stub for the abstract AgentRpc servicer.

    All six RPC methods are abstract.  Each return annotation is a union of
    a synchronous form and its asynchronous counterpart (iterator / async
    iterator for the streaming RPCs, value / awaitable for the unary RPCs).
    """

    @abc.abstractmethod
    def OpenChannel(
        self,
        request_iterator: _MaybeAsyncIterator[agent_worker_pb2.Message],
        context: _ServicerContext,
    ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ...

    @abc.abstractmethod
    def OpenControlChannel(
        self,
        request_iterator: _MaybeAsyncIterator[agent_worker_pb2.ControlMessage],
        context: _ServicerContext,
    ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.ControlMessage], collections.abc.AsyncIterator[agent_worker_pb2.ControlMessage]]: ...

    @abc.abstractmethod
    def RegisterAgent(
        self,
        request: agent_worker_pb2.RegisterAgentTypeRequest,
        context: _ServicerContext,
    ) -> typing.Union[agent_worker_pb2.RegisterAgentTypeResponse, collections.abc.Awaitable[agent_worker_pb2.RegisterAgentTypeResponse]]: ...

    @abc.abstractmethod
    def AddSubscription(
        self,
        request: agent_worker_pb2.AddSubscriptionRequest,
        context: _ServicerContext,
    ) -> typing.Union[agent_worker_pb2.AddSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.AddSubscriptionResponse]]: ...

    @abc.abstractmethod
    def RemoveSubscription(
        self,
        request: agent_worker_pb2.RemoveSubscriptionRequest,
        context: _ServicerContext,
    ) -> typing.Union[agent_worker_pb2.RemoveSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.RemoveSubscriptionResponse]]: ...

    @abc.abstractmethod
    def GetSubscriptions(
        self,
        request: agent_worker_pb2.GetSubscriptionsRequest,
        context: _ServicerContext,
    ) -> typing.Union[agent_worker_pb2.GetSubscriptionsResponse, collections.abc.Awaitable[agent_worker_pb2.GetSubscriptionsResponse]]: ...
|
||||
|
||||
# Type stub for the registration helper: accepts either a grpc.Server or a
# grpc.aio.Server and returns nothing.
def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...
|
||||
@@ -0,0 +1,45 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: cloudevent.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# Ask the installed protobuf runtime to validate itself against the version
# this file was generated with (5.29.0, public domain, empty suffix) for
# cloudevent.proto.  Runs at import time.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'cloudevent.proto'
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63loudevent.proto\x12\x11io.cloudevents.v1\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xb0\x04\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x41\n\nattributes\x18\x05 \x03(\x0b\x32-.io.cloudevents.v1.CloudEvent.AttributesEntry\x12\x15\n\x0b\x62inary_data\x18\x06 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x07 \x01(\tH\x00\x12*\n\nproto_data\x18\x08 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1ai\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.io.cloudevents.v1.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3')
|
||||
|
||||
# Hand this module's namespace to the protobuf builder so the names derived
# from DESCRIPTOR are installed as module globals.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cloudevent_pb2', _globals)
# Only when the C descriptor implementation is not in use: set the serialized
# options and the start/end offsets of each message on the descriptors by hand.
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\252\002\033Microsoft.AutoGen.Contracts'
  _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._loaded_options = None
  _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_options = b'8\001'
  _globals['_CLOUDEVENT']._serialized_start=100
  _globals['_CLOUDEVENT']._serialized_end=660
  _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_start=333
  _globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_end=438
  _globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_start=441
  _globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_end=652
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
*
|
||||
CloudEvent Protobuf Format
|
||||
|
||||
- Required context attributes are explicitly represented.
|
||||
- Optional and Extension context attributes are carried in a map structure.
|
||||
- Data may be represented as binary, text, or protobuf messages.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.any_pb2
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import google.protobuf.timestamp_pb2
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
class CloudEvent(google.protobuf.message.Message):
    # Type stub for the CloudEvent message: four string context attributes
    # (id, source, spec_version, type), an `attributes` map, and the `data`
    # oneof group (binary_data / text_data / proto_data).
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    @typing.final
    class AttributesEntry(google.protobuf.message.Message):
        # Entry type of the `attributes` map: str key -> CloudEventAttributeValue.
        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: builtins.str
        @property
        def value(self) -> global___CloudEvent.CloudEventAttributeValue: ...
        def __init__(
            self,
            *,
            key: builtins.str = ...,
            value: global___CloudEvent.CloudEventAttributeValue | None = ...,
        ) -> None: ...
        def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

    @typing.final
    class CloudEventAttributeValue(google.protobuf.message.Message):
        """*
        The CloudEvent specification defines
        seven attribute value types...
        """

        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        CE_BOOLEAN_FIELD_NUMBER: builtins.int
        CE_INTEGER_FIELD_NUMBER: builtins.int
        CE_STRING_FIELD_NUMBER: builtins.int
        CE_BYTES_FIELD_NUMBER: builtins.int
        CE_URI_FIELD_NUMBER: builtins.int
        CE_URI_REF_FIELD_NUMBER: builtins.int
        CE_TIMESTAMP_FIELD_NUMBER: builtins.int
        # The seven ce_* fields are the members of the `attr` oneof group
        # (see WhichOneof below).
        ce_boolean: builtins.bool
        ce_integer: builtins.int
        ce_string: builtins.str
        ce_bytes: builtins.bytes
        ce_uri: builtins.str
        ce_uri_ref: builtins.str
        @property
        def ce_timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
        def __init__(
            self,
            *,
            ce_boolean: builtins.bool = ...,
            ce_integer: builtins.int = ...,
            ce_string: builtins.str = ...,
            ce_bytes: builtins.bytes = ...,
            ce_uri: builtins.str = ...,
            ce_uri_ref: builtins.str = ...,
            ce_timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ...,
        ) -> None: ...
        def HasField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> None: ...
        def WhichOneof(self, oneof_group: typing.Literal["attr", b"attr"]) -> typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"] | None: ...

    ID_FIELD_NUMBER: builtins.int
    SOURCE_FIELD_NUMBER: builtins.int
    SPEC_VERSION_FIELD_NUMBER: builtins.int
    TYPE_FIELD_NUMBER: builtins.int
    ATTRIBUTES_FIELD_NUMBER: builtins.int
    BINARY_DATA_FIELD_NUMBER: builtins.int
    TEXT_DATA_FIELD_NUMBER: builtins.int
    PROTO_DATA_FIELD_NUMBER: builtins.int
    id: builtins.str
    """-- CloudEvent Context Attributes

    Required Attributes
    """
    source: builtins.str
    """URI-reference"""
    spec_version: builtins.str
    type: builtins.str
    # binary_data, text_data and proto_data are the members of the `data`
    # oneof group (see WhichOneof below).
    binary_data: builtins.bytes
    text_data: builtins.str
    @property
    def attributes(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
        """Optional & Extension Attributes"""

    @property
    def proto_data(self) -> google.protobuf.any_pb2.Any: ...
    def __init__(
        self,
        *,
        id: builtins.str = ...,
        source: builtins.str = ...,
        spec_version: builtins.str = ...,
        type: builtins.str = ...,
        attributes: collections.abc.Mapping[builtins.str, global___CloudEvent.CloudEventAttributeValue] | None = ...,
        binary_data: builtins.bytes = ...,
        text_data: builtins.str = ...,
        proto_data: google.protobuf.any_pb2.Any | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "binary_data", b"binary_data", "data", b"data", "id", b"id", "proto_data", b"proto_data", "source", b"source", "spec_version", b"spec_version", "text_data", b"text_data", "type", b"type"]) -> None: ...
    def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
|
||||
|
||||
global___CloudEvent = CloudEvent
|
||||
@@ -0,0 +1,24 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
|
||||
# grpcio version this module was generated against, and the version installed.
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__

try:
    from grpc._utilities import first_version_is_lower

    # Unsupported when the installed grpcio is older than the generator's.
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The comparison helper cannot be imported: treat as unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        ' but the generated code in cloudevent_pb2_grpc.py depends on'
        f' grpcio>={GRPC_GENERATED_VERSION}.'
        f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
*
|
||||
CloudEvent Protobuf Format
|
||||
|
||||
- Required context attributes are explicitly represented.
|
||||
- Optional and Extension context attributes are carried in a map structure.
|
||||
- Data may be represented as binary, text, or protobuf messages.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import collections.abc
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import typing
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
# Typing-only helper: an abstract type that is both an AsyncIterator[_T] and
# an Iterator[_T].
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
|
||||
|
||||
# Typing-only helper combining grpc.ServicerContext and
# grpc.aio.ServicerContext.
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext):  # type: ignore[misc, type-arg]
    ...
|
||||
0
agent_dhal/agentdhal_extensions/teams/__init__.py
Normal file
0
agent_dhal/agentdhal_extensions/teams/__init__.py
Normal file
287
agent_dhal/agentdhal_extensions/teams/magentic_one.py
Normal file
287
agent_dhal/agentdhal_extensions/teams/magentic_one.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import warnings
|
||||
from typing import Awaitable, Callable, List, Optional, Union
|
||||
|
||||
from agentdhal_agentchat.agents import ApprovalFuncType, CodeExecutorAgent, UserProxyAgent
|
||||
from agentdhal_agentchat.base import ChatAgent
|
||||
from agentdhal_agentchat.teams import MagenticOneGroupChat
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.code_executor import CodeExecutor
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
|
||||
from agentdhal_extensions.agents.file_surfer import FileSurfer
|
||||
from agentdhal_extensions.agents.magentic_one import MagenticOneCoderAgent
|
||||
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
|
||||
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
|
||||
from agentdhal_extensions.models.openai._openai_client import BaseOpenAIChatCompletionClient
|
||||
|
||||
# Docker imports for default code executor
|
||||
try:
|
||||
import docker
|
||||
from docker.errors import DockerException
|
||||
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
|
||||
_docker_available = True
|
||||
except ImportError:
|
||||
docker = None # type: ignore
|
||||
DockerException = Exception # type: ignore
|
||||
DockerCommandLineCodeExecutor = None # type: ignore
|
||||
_docker_available = False
|
||||
|
||||
SyncInputFunc = Callable[[str], str]
|
||||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
|
||||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
|
||||
|
||||
|
||||
def _is_docker_available() -> bool:
    """Check if Docker is available and running.

    Returns:
        ``True`` only when the ``docker`` package was imported successfully,
        a client could be created from the environment, and that client's
        ``ping()`` completed without raising ``DockerException``; ``False``
        otherwise.
    """
    if not _docker_available or docker is None:
        return False

    try:
        client = docker.from_env()
    except DockerException:
        return False

    try:
        client.ping()  # type: ignore
        return True
    except DockerException:
        return False
    finally:
        # The probe client is only needed for the ping; close it so each
        # availability check does not leave its session open.
        client.close()
|
||||
|
||||
|
||||
def _create_default_code_executor() -> CodeExecutor:
    """Build the fallback code executor used when the caller supplies none.

    A ``DockerCommandLineCodeExecutor`` is returned when Docker is reachable
    and the executor can be constructed.  In every other case a
    ``UserWarning`` is emitted and a ``LocalCommandLineCodeExecutor`` is
    returned instead.
    """
    executor: CodeExecutor | None = None

    if _is_docker_available() and DockerCommandLineCodeExecutor is not None:
        try:
            executor = DockerCommandLineCodeExecutor()
        except Exception:
            # Best effort: any construction failure falls through to the
            # local executor below.
            executor = None

    if executor is not None:
        return executor

    # No usable Docker executor: warn and fall back to running locally.
    warnings.warn(
        "Docker is not available or not running. Using LocalCommandLineCodeExecutor instead of the recommended DockerCommandLineCodeExecutor. "
        "For security, it is recommended to install Docker and ensure it's running before using MagenticOne. "
        "To install Docker, visit: https://docs.docker.com/get-docker/",
        UserWarning,
        stacklevel=3,
    )
    return LocalCommandLineCodeExecutor()
|
||||
|
||||
|
||||
class MagenticOne(MagenticOneGroupChat):
|
||||
"""
|
||||
MagenticOne is a specialized group chat class that integrates various agents
|
||||
such as FileSurfer, WebSurfer, Coder, and Executor to solve complex tasks.
|
||||
To read more about the science behind Magentic-One, see the full blog post: `Magentic-One: A Generalist Multi-Agent System for Solving Complex Tasks <https://www.microsoft.com/en-us/research/articles/magentic-one-a-generalist-multi-agent-system-for-solving-complex-tasks>`_ and the references below.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[magentic-one]"
|
||||
|
||||
|
||||
Args:
|
||||
client (ChatCompletionClient): The client used for model interactions.
|
||||
hil_mode (bool): Optional; If set to True, adds the UserProxyAgent to the list of agents.
|
||||
input_func (InputFuncType | None): Optional; Function to use for user input in human-in-the-loop mode.
|
||||
code_executor (CodeExecutor | None): Optional; Code executor to use. If None, will use Docker if available, otherwise local executor.
|
||||
approval_func (ApprovalFuncType | None): Optional; Function to approve code execution before running. If None, code will execute without approval.
|
||||
|
||||
.. warning::
|
||||
Using Magentic-One involves interacting with a digital world designed for humans, which carries inherent risks. To minimize these risks, consider the following precautions:
|
||||
|
||||
1. **Use Containers**: Run all tasks in docker containers to isolate the agents and prevent direct system attacks.
|
||||
2. **Virtual Environment**: Use a virtual environment to run the agents and prevent them from accessing sensitive data.
|
||||
3. **Monitor Logs**: Closely monitor logs during and after execution to detect and mitigate risky behavior.
|
||||
4. **Human Oversight**: Run the examples with a human in the loop to supervise the agents and prevent unintended consequences.
|
||||
5. **Limit Access**: Restrict the agents' access to the internet and other resources to prevent unauthorized actions.
|
||||
6. **Safeguard Data**: Ensure that the agents do not have access to sensitive data or resources that could be compromised. Do not share sensitive information with the agents.
|
||||
|
||||
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences. Moreover, be cautious that Magentic-One may be susceptible to prompt injection attacks from webpages.
|
||||
|
||||
Architecture:
|
||||
|
||||
Magentic-One is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of domains. It represents a significant step towards developing agents that can complete tasks that people encounter in their work and personal lives.
|
||||
|
||||
Magentic-One is based on a multi-agent architecture where a lead Orchestrator agent is responsible for high-level planning, directing other agents, and tracking task progress. The Orchestrator begins by creating a plan to tackle the task, gathering needed facts and educated guesses in a Task Ledger that is maintained. At each step of its plan, the Orchestrator creates a Progress Ledger where it self-reflects on task progress and checks whether the task is completed. If the task is not yet completed, it assigns one of Magentic-One's other agents a subtask to complete. After the assigned agent completes its subtask, the Orchestrator updates the Progress Ledger and continues in this way until the task is complete. If the Orchestrator finds that progress is not being made for enough steps, it can update the Task Ledger and create a new plan.
|
||||
|
||||
Overall, Magentic-One consists of the following agents:
|
||||
|
||||
- Orchestrator: The lead agent responsible for task decomposition and planning, directing other agents in executing subtasks, tracking overall progress, and taking corrective actions as needed.
|
||||
- WebSurfer: An LLM-based agent proficient in commanding and managing the state of a Chromium-based web browser. It performs actions on the browser and reports on the new state of the web page.
|
||||
- FileSurfer: An LLM-based agent that commands a markdown-based file preview application to read local files of most types. It can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure.
|
||||
- Coder: An LLM-based agent specialized in writing code, analyzing information collected from other agents, or creating new artifacts.
|
||||
- ComputerTerminal: Provides the team with access to a console shell where the Coder's programs can be executed, and where new programming libraries can be installed.
|
||||
|
||||
Together, Magentic-One's agents provide the Orchestrator with the tools and capabilities needed to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Autonomously complete a coding task:
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.teams.magentic_one import MagenticOne
|
||||
from agentdhal_agentchat.ui import Console
|
||||
|
||||
|
||||
async def example_usage():
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
m1 = MagenticOne(client=client) # Uses DockerCommandLineCodeExecutor by default
|
||||
task = "Write a Python script to fetch data from an API."
|
||||
result = await Console(m1.run_stream(task=task))
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Enable human-in-the-loop mode with explicit Docker executor and code approval
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.teams.magentic_one import MagenticOne
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse
|
||||
|
||||
|
||||
def user_input_func(prompt: str) -> str:
|
||||
\"\"\"Custom input function for user interaction.\"\"\"
|
||||
return input(prompt)
|
||||
|
||||
|
||||
def approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Simple approval function that requests user input.\"\"\"
|
||||
print(f\"Code to execute:\\n{request.code}\")
|
||||
user_input = input("Do you approve this code execution? (y/n): ").strip().lower()
|
||||
if user_input == 'y':
|
||||
return ApprovalResponse(approved=True, reason=\"User approved the code execution\")
|
||||
else:
|
||||
return ApprovalResponse(approved=False, reason=\"User denied the code execution\")
|
||||
|
||||
|
||||
async def example_usage_hil():
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
# Explicitly specify Docker code executor for better security
|
||||
async with DockerCommandLineCodeExecutor() as code_executor:
|
||||
m1 = MagenticOne(
|
||||
client=client,
|
||||
hil_mode=True,
|
||||
input_func=user_input_func,
|
||||
code_executor=code_executor,
|
||||
approval_func=approval_func
|
||||
)
|
||||
task = "Write a Python script to fetch data from an API."
|
||||
result = await Console(m1.run_stream(task=task))
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage_hil())
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Enable code execution approval without human-in-the-loop mode
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.teams.magentic_one import MagenticOne
|
||||
from agentdhal_extensions.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.agents import ApprovalRequest, ApprovalResponse
|
||||
|
||||
|
||||
def approval_func(request: ApprovalRequest) -> ApprovalResponse:
|
||||
\"\"\"Simple approval function that requests user input.\"\"\"
|
||||
print(f\"Code to execute:\\n{request.code}\")
|
||||
user_input = input("Do you approve this code execution? (y/n): ").strip().lower()
|
||||
if user_input == 'y':
|
||||
return ApprovalResponse(approved=True, reason=\"User approved the code execution\")
|
||||
else:
|
||||
return ApprovalResponse(approved=False, reason=\"User denied the code execution\")
|
||||
|
||||
|
||||
async def example_usage_with_approval():
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
# Use approval_func for code approval only (hil_mode=False)
|
||||
async with DockerCommandLineCodeExecutor() as code_executor:
|
||||
m1 = MagenticOne(
|
||||
client=client,
|
||||
hil_mode=False, # No human-in-the-loop for general conversation
|
||||
code_executor=code_executor,
|
||||
approval_func=approval_func # But still ask for code execution approval
|
||||
)
|
||||
task = "Write a Python script to fetch data from an API."
|
||||
result = await Console(m1.run_stream(task=task))
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage_with_approval())
|
||||
|
||||
References:
|
||||
.. code-block:: bibtex
|
||||
|
||||
@article{fourney2024magentic,
|
||||
title={Magentic-one: A generalist multi-agent system for solving complex tasks},
|
||||
author={Fourney, Adam and Bansal, Gagan and Mozannar, Hussein and Tan, Cheng and Salinas, Eduardo and Niedtner, Friederike and Proebsting, Grace and Bassman, Griffin and Gerrits, Jack and Alber, Jacob and others},
|
||||
journal={arXiv preprint arXiv:2411.04468},
|
||||
year={2024},
|
||||
url={https://arxiv.org/abs/2411.04468}
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    client: ChatCompletionClient,
    hil_mode: bool = False,
    input_func: InputFuncType | None = None,
    code_executor: CodeExecutor | None = None,
    approval_func: ApprovalFuncType | None = None,
):
    """Assemble the MagenticOne team around a single model client.

    Args:
        client: Model client shared by every model-backed agent and by the group chat itself.
        hil_mode: When true, a ``UserProxyAgent`` named ``"User"`` is added to the team.
        input_func: Input callback handed to the ``UserProxyAgent``; only used when
            ``hil_mode`` is true.
        code_executor: Executor used by the ``"ComputerTerminal"`` agent. Omitting it is
            deprecated: a ``DeprecationWarning`` is emitted and the executor returned by
            ``_create_default_code_executor()`` is used instead.
        approval_func: Approval callback forwarded to the ``CodeExecutorAgent``.
    """
    self.client = client
    self._validate_client_capabilities(client)

    if code_executor is None:
        deprecation_text = "Instantiating MagenticOne without a code_executor is deprecated. Provide a code_executor to clear this warning (e.g., code_executor=DockerCommandLineCodeExecutor() )."
        warnings.warn(deprecation_text, DeprecationWarning, stacklevel=2)
        code_executor = _create_default_code_executor()

    # Agents are created in the same order they are registered with the team.
    team: List[ChatAgent] = [
        FileSurfer("FileSurfer", model_client=client),
        MultimodalWebSurfer("WebSurfer", model_client=client),
        MagenticOneCoderAgent("Coder", model_client=client),
        CodeExecutorAgent("ComputerTerminal", code_executor=code_executor, approval_func=approval_func),
    ]
    if hil_mode:
        # Human-in-the-loop: the user proxy joins last.
        team.append(UserProxyAgent("User", input_func=input_func))

    super().__init__(team, model_client=client)
|
||||
|
||||
def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
    """Warn (never raise) when ``client`` looks like a poor fit for MagenticOne.

    Two independent checks are made:

    1. ``client.model_info`` must report truthy ``function_calling`` and
       ``json_output`` entries; otherwise a warning naming the missing
       capabilities is emitted.
    2. A client that is not a ``BaseOpenAIChatCompletionClient`` triggers a
       warning that MagenticOne performs best with GPT-4o.

    Args:
        client: The model client about to be used by the team.
    """
    capabilities = client.model_info
    required_capabilities = ["function_calling", "json_output"]

    # The previous message claimed "vision" was required although it was never
    # checked; the message now describes exactly what is verified here.
    missing = [cap for cap in required_capabilities if not capabilities.get(cap)]
    if missing:
        warnings.warn(
            "Client capabilities for MagenticOne must include function calling and json output "
            f"(missing or disabled: {', '.join(missing)}).",
            stacklevel=2,
        )

    if not isinstance(client, BaseOpenAIChatCompletionClient):
        warnings.warn(
            "MagenticOne performs best with OpenAI GPT-4o model either through OpenAI or Azure OpenAI.",
            stacklevel=2,
        )
|
||||
19
agent_dhal/agentdhal_extensions/tools/azure/__init__.py
Normal file
19
agent_dhal/agentdhal_extensions/tools/azure/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from ._ai_search import (
|
||||
AzureAISearchTool,
|
||||
BaseAzureAISearchTool,
|
||||
SearchQuery,
|
||||
SearchResult,
|
||||
SearchResults,
|
||||
VectorizableTextQuery,
|
||||
)
|
||||
from ._config import AzureAISearchConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAISearchTool",
|
||||
"BaseAzureAISearchTool",
|
||||
"SearchQuery",
|
||||
"SearchResult",
|
||||
"SearchResults",
|
||||
"AzureAISearchConfig",
|
||||
"VectorizableTextQuery",
|
||||
]
|
||||
1137
agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py
Normal file
1137
agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py
Normal file
File diff suppressed because it is too large
Load Diff
186
agent_dhal/agentdhal_extensions/tools/azure/_config.py
Normal file
186
agent_dhal/agentdhal_extensions/tools/azure/_config.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Configuration for Azure AI Search tool.
|
||||
|
||||
This module provides configuration classes for the Azure AI Search tool, including
|
||||
settings for authentication, search behavior, retry policies, and caching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
# Type variable bound to the config class.
# NOTE(review): not referenced anywhere else in this module as shown — confirm whether it is still needed.
T = TypeVar("T", bound="AzureAISearchConfig")

# Module-level logger.
# NOTE(review): no logging calls appear in this module as shown.
logger = logging.getLogger(__name__)

# Closed set of values accepted by AzureAISearchConfig.query_type.
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
# Default for AzureAISearchConfig.api_version.
DEFAULT_API_VERSION = "2023-10-01-preview"
|
||||
|
||||
|
||||
class AzureAISearchConfig(BaseModel):
    """Validated configuration for Azure AI Search tools.

    The model groups the settings for authentication, search behavior, result
    caching, and client-side embeddings. Individual fields are checked by field
    validators, and combinations of fields are checked once the whole model has
    been parsed (see :meth:`validate_interdependent_fields`).

    .. note::
        This class requires the ``azure`` extra for the ``agentdhal-ext`` package.

        .. code-block:: bash

            pip install -U "agentdhal-ext[azure]"

    .. note::
        **Prerequisites:**

        1. An Azure AI Search service must be created in your Azure subscription.
        2. The search index must be properly configured for your use case:

           - For vector search: Index must have vector fields
           - For semantic search: Index must have semantic configuration
           - For hybrid search: Both vector fields and text fields must be configured
        3. Required packages:

           - Base functionality: ``azure-search-documents>=11.4.0``
           - For Azure OpenAI embeddings: ``openai azure-identity``
           - For OpenAI embeddings: ``openai``

    Example Usage:
        .. code-block:: python

            from azure.core.credentials import AzureKeyCredential
            from agentdhal_extensions.tools.azure import AzureAISearchConfig

            # Basic configuration for full-text search
            config = AzureAISearchConfig(
                name="doc-search",
                endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                index_name="<your-index>",  # Name of your search index
                credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                query_type="simple",
                search_fields=["content", "title"],  # Update with your searchable fields
                top=5,
            )

            # Configuration for vector search with Azure OpenAI embeddings
            vector_config = AzureAISearchConfig(
                name="vector-search",
                endpoint="https://your-search.search.windows.net",
                index_name="<your-index>",
                credential=AzureKeyCredential("<your-key>"),
                query_type="vector",
                vector_fields=["embedding"],  # Update with your vector field name
                embedding_provider="azure_openai",
                embedding_model="text-embedding-ada-002",
                openai_endpoint="https://your-openai.openai.azure.com",  # Your Azure OpenAI endpoint
                openai_api_key="<your-openai-key>",  # Your Azure OpenAI key
                top=5,
            )

            # Configuration for hybrid search with semantic ranking
            hybrid_config = AzureAISearchConfig(
                name="hybrid-search",
                endpoint="https://your-search.search.windows.net",
                index_name="<your-index>",
                credential=AzureKeyCredential("<your-key>"),
                query_type="semantic",
                semantic_config_name="<your-semantic-config>",  # Name of your semantic configuration
                search_fields=["content", "title"],  # Update with your search fields
                vector_fields=["embedding"],  # Update with your vector field name
                embedding_provider="openai",
                embedding_model="text-embedding-ada-002",
                openai_api_key="<your-openai-key>",  # Your OpenAI API key
                top=5,
            )
    """

    name: str = Field(description="The name of this tool instance")
    description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
    endpoint: str = Field(description="The full URL of your Azure AI Search service")
    index_name: str = Field(description="Name of the search index to query")
    credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
        description="Azure credential for authentication (API key or token)"
    )
    api_version: str = Field(
        default=DEFAULT_API_VERSION,
        description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
    )
    query_type: QueryTypeLiteral = Field(
        default="simple", description="Type of search to perform: simple, full, semantic, or vector"
    )
    search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
    select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
    vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
    top: Optional[int] = Field(
        default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
    )
    filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
    semantic_config_name: Optional[str] = Field(
        default=None, description="Semantic configuration name for enhanced results"
    )

    enable_caching: bool = Field(default=False, description="Whether to cache search results")
    cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")

    embedding_provider: Optional[str] = Field(
        default=None, description="Name of embedding provider for client-side embeddings"
    )
    embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
    openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
    openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
    openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")

    # The credential classes are not pydantic-aware types, so arbitrary types must be allowed.
    model_config = {"arbitrary_types_allowed": True}

    @field_validator("endpoint")
    @classmethod
    def validate_endpoint(cls, v: str) -> str:
        """Require the endpoint to start with ``http://`` or ``https://``.

        Raises:
            ValueError: If the value has neither prefix.
        """
        if not v.startswith(("http://", "https://")):
            raise ValueError("endpoint must be a valid URL starting with http:// or https://")
        return v

    @field_validator("query_type", mode="before")
    @classmethod
    def normalize_query_type(cls, v: object) -> object:
        """Map legacy or empty query types onto the supported literals.

        This runs *before* the ``Literal`` check. In the default "after" mode the
        literal validation would already have rejected ``"fulltext"`` and empty
        values, so the normalization below could never take effect.

        Returns:
            ``"simple"`` for ``None`` or an empty string, ``"full"`` for any casing
            of ``"fulltext"``, otherwise the value unchanged (it is then validated
            against the allowed literals as usual).
        """
        if v is None or v == "":
            return "simple"

        if isinstance(v, str) and v.lower() == "fulltext":
            return "full"

        return v

    @field_validator("top")
    @classmethod
    def validate_top(cls, v: Optional[int]) -> Optional[int]:
        """Ensure ``top`` is a positive integer when it is provided.

        Raises:
            ValueError: If ``top`` is zero or negative.
        """
        if v is not None and v <= 0:
            raise ValueError("top must be a positive integer")
        return v

    @model_validator(mode="after")
    def validate_interdependent_fields(self) -> "AzureAISearchConfig":
        """Check combinations of fields once every field has been parsed.

        Raises:
            ValueError: If ``query_type`` is ``"semantic"`` without a
                ``semantic_config_name``, if ``query_type`` is ``"vector"`` without
                ``vector_fields``, or if the ``azure_openai`` embedding provider is
                selected together with an ``embedding_model`` but no
                ``openai_endpoint``.
        """
        if self.query_type == "semantic" and not self.semantic_config_name:
            raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")

        if self.query_type == "vector" and not self.vector_fields:
            raise ValueError("vector_fields must be provided for vector search")

        if (
            self.embedding_provider
            and self.embedding_provider.lower() == "azure_openai"
            and self.embedding_model
            and not self.openai_endpoint
        ):
            raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")

        return self
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
|
||||
__all__ = ["CodeExecutionInput", "CodeExecutionResult", "PythonCodeExecutionTool"]
|
||||
@@ -0,0 +1,96 @@
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class CodeExecutionInput(BaseModel):
    # Argument schema of PythonCodeExecutionTool: ``code`` is wrapped in a single
    # Python code block and passed to the executor by PythonCodeExecutionTool.run().
    code: str = Field(description="The contents of the Python code block that should be executed")
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
    # Result returned by PythonCodeExecutionTool.run().
    # True when the executor reported exit code 0.
    success: bool
    # Output text reported by the executor.
    output: str

    @model_serializer
    def ser_model(self) -> str:
        """Serialize the model as the bare ``output`` string; ``success`` is not part of the serialized form."""
        return self.output
|
||||
|
||||
|
||||
class PythonCodeExecutionToolConfig(BaseModel):
    """Configuration for PythonCodeExecutionTool"""

    # Serialized component model of the code executor (produced with
    # ``dump_component()`` and restored with ``CodeExecutor.load_component``).
    executor: ComponentModel
    # Description text for the tool.
    description: str = "Execute Python code blocks."
|
||||
|
||||
|
||||
class PythonCodeExecutionTool(
    BaseTool[CodeExecutionInput, CodeExecutionResult], Component[PythonCodeExecutionToolConfig]
):
    """Tool that runs a Python snippet through a code executor and reports its output.

    The snippet arrives as :class:`CodeExecutionInput`, is submitted to the
    configured executor as one ``python`` code block, and comes back as a
    :class:`CodeExecutionResult` whose ``success`` flag is true when the exit
    code is ``0``.

    Executors that can be plugged in include:

    * :class:`agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor`
    * :class:`agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
    * :class:`agentdhal_extensions.code_executors.azure.ACADynamicSessionsCodeExecutor`

    Example usage:

    .. code-block:: bash

        pip install -U "agentdhal-agentchat" "agentdhal-ext[openai]" "yfinance" "matplotlib"

    .. code-block:: python

        import asyncio
        from agentdhal_agentchat.agents import AssistantAgent
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
        from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool


        async def main() -> None:
            tool = PythonCodeExecutionTool(LocalCommandLineCodeExecutor(work_dir="coding"))
            agent = AssistantAgent(
                "assistant", OpenAIChatCompletionClient(model="gpt-4o"), tools=[tool], reflect_on_tool_use=True
            )
            await Console(
                agent.run_stream(
                    task="Create a plot of MSFT stock prices in 2024 and save it to a file. Use yfinance and matplotlib."
                )
            )


        asyncio.run(main())


    Args:
        executor (CodeExecutor): The code executor that runs the submitted code blocks.
    """

    component_config_schema = PythonCodeExecutionToolConfig
    component_provider_override = "agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool"

    def __init__(self, executor: CodeExecutor):
        # Tool name and description are fixed; only the executor is configurable.
        super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
        self._executor = executor

    async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
        """Execute ``args.code`` as a single Python block and wrap the executor's result."""
        outcome = await self._executor.execute_code_blocks(
            code_blocks=[CodeBlock(code=args.code, language="python")],
            cancellation_token=cancellation_token,
        )
        succeeded = outcome.exit_code == 0
        return CodeExecutionResult(success=succeeded, output=outcome.output)

    def _to_config(self) -> PythonCodeExecutionToolConfig:
        """Dump this tool as a config object holding the serialized executor."""
        serialized_executor = self._executor.dump_component()
        return PythonCodeExecutionToolConfig(executor=serialized_executor)

    @classmethod
    def _from_config(cls, config: PythonCodeExecutionToolConfig) -> Self:
        """Rebuild a tool from a config object by loading its executor component."""
        return cls(executor=CodeExecutor.load_component(config.executor))
|
||||
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from ._config import (
|
||||
GlobalContextConfig,
|
||||
GlobalDataConfig,
|
||||
LocalContextConfig,
|
||||
LocalDataConfig,
|
||||
MapReduceConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
|
||||
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
|
||||
|
||||
__all__ = [
|
||||
"GlobalSearchTool",
|
||||
"LocalSearchTool",
|
||||
"GlobalDataConfig",
|
||||
"LocalDataConfig",
|
||||
"GlobalContextConfig",
|
||||
"GlobalSearchToolArgs",
|
||||
"GlobalSearchToolReturn",
|
||||
"LocalContextConfig",
|
||||
"LocalSearchToolArgs",
|
||||
"LocalSearchToolReturn",
|
||||
"MapReduceConfig",
|
||||
"SearchConfig",
|
||||
]
|
||||
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
    # Base settings describing where the GraphRAG index tables live.
    # Directory (or path prefix) of the index output; the global search tool
    # loads each table from "<input_dir>/<table>.parquet".
    input_dir: str
    # Base name (without ".parquet") of the entity table.
    entity_table: str = "entities"
    # NOTE(review): not read by the global search tool; presumably names the
    # entity embedding store for local search — confirm in _local_search.py.
    entity_embedding_table: str = "entities"
    # Base name (without ".parquet") of the community table.
    community_table: str = "communities"
    # Community level passed to the GraphRAG indexer adapters when reading
    # reports and entities.
    community_level: int = 2
|
||||
|
||||
|
||||
class GlobalDataConfig(DataConfig):
    # Data settings for GlobalSearchTool: adds the base name (without
    # ".parquet") of the community report table.
    community_report_table: str = "community_reports"
|
||||
|
||||
|
||||
class LocalDataConfig(DataConfig):
    # Data settings imported by the local search module as its ``DataConfig``.
    # NOTE(review): presumably base names of the relationship and text-unit
    # parquet tables, by analogy with the global tool — confirm in _local_search.py.
    relationship_table: str = "relationships"
    text_unit_table: str = "text_units"
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
    # Base settings for context building, shared by the global and local variants.
    # Token budget for the data placed in the search context.
    max_data_tokens: int = 8000
|
||||
|
||||
|
||||
class GlobalContextConfig(ContextConfig):
    # Context-building settings for GlobalSearchTool. Every field except
    # ``max_data_tokens`` is forwarded under its own name in the
    # ``context_builder_params`` dict given to GraphRAG's GlobalSearch.
    use_community_summary: bool = False
    shuffle_data: bool = True
    include_community_rank: bool = True
    min_community_rank: int = 0
    community_rank_name: str = "rank"
    include_community_weight: bool = True
    community_weight_name: str = "occurrence weight"
    normalize_community_weight: bool = True
    # Overrides the ContextConfig default (8000). Forwarded both as the
    # context builder's "max_tokens" and as GlobalSearch's ``max_data_tokens``.
    max_data_tokens: int = 12000
|
||||
|
||||
|
||||
class LocalContextConfig(ContextConfig):
    # Context-building settings for the local search tool (inherits
    # ``max_data_tokens`` = 8000 from ContextConfig).
    # NOTE(review): how each field is consumed is defined in _local_search.py,
    # which is not fully visible here; the names suggest they map onto GraphRAG
    # local-search context parameters — confirm there.
    text_unit_prop: float = 0.5
    community_prop: float = 0.25
    include_entity_rank: bool = True
    rank_description: str = "number of relationships"
    include_relationship_weight: bool = True
    relationship_ranking_attribute: str = "rank"
|
||||
|
||||
|
||||
class MapReduceConfig(BaseModel):
    # Settings for the map and reduce LLM calls made by GlobalSearchTool.
    # Sent as "max_tokens" / "temperature" in the map-step LLM params.
    map_max_tokens: int = 1000
    map_temperature: float = 0.0
    # Sent as "max_tokens" / "temperature" in the reduce-step LLM params.
    reduce_max_tokens: int = 2000
    reduce_temperature: float = 0.0
    # The three fields below are forwarded to GraphRAG's GlobalSearch under
    # the same names.
    allow_general_knowledge: bool = False
    json_mode: bool = False
    response_type: str = "multiple paragraphs"
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
    # Search settings imported by the local search module.
    # NOTE(review): presumably the LLM generation limits and response style
    # for local search — confirm against _local_search.py.
    max_tokens: int = 1500
    temperature: float = 0.0
    response_type: str = "multiple paragraphs"
|
||||
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_communities,
|
||||
read_indexer_entities,
|
||||
read_indexer_reports,
|
||||
)
|
||||
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
|
||||
from graphrag.query.structured_search.global_search.search import GlobalSearch
|
||||
|
||||
from ._config import GlobalContextConfig as ContextConfig
|
||||
from ._config import GlobalDataConfig as DataConfig
|
||||
from ._config import MapReduceConfig
|
||||
|
||||
# Module-level default instances. They serve as the default argument values of
# GlobalSearchTool.__init__ and are passed explicitly by from_settings(), so the
# same objects are shared by every tool that does not supply its own config.
_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()
|
||||
|
||||
|
||||
class GlobalSearchToolArgs(BaseModel):
    # Argument schema of GlobalSearchTool: ``query`` is the text passed to the
    # search engine in GlobalSearchTool.run().
    query: str = Field(..., description="The user query to perform global search on.")
|
||||
|
||||
|
||||
class GlobalSearchToolReturn(BaseModel):
    # Return schema of GlobalSearchTool: ``answer`` carries the search result's
    # ``response`` string.
    answer: str
|
||||
|
||||
|
||||
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
    """Tool wrapper that exposes GraphRAG *global search* to agents.

    On construction the tool loads the community, entity, and community-report
    parquet tables of a GraphRAG index, builds a ``GlobalCommunityContext`` from
    them, and wires that into a GraphRAG ``GlobalSearch`` engine. Each call to
    :meth:`run` forwards the query text to that engine and returns its response.

    .. note::
        This tool requires the :code:`graphrag` extra.

        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        The GraphRAG setup and indexing process must be finished before the tool can be used:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        See the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
        for detailed instructions on these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import GlobalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up global search tool
            global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the global search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[global_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
                ),
            )

            # Run a sample query
            query = "What is the overall sentiment of the community reports?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())

    Args:
        token_encoder (tiktoken.Encoding): Tokenizer handed to the context builder and the search engine.
        model: GraphRAG chat model used by the search engine.
        data_config (DataConfig): Location and names of the index tables to load.
        context_config (ContextConfig, optional): Context-building settings. Defaults to the module default.
        mapreduce_config (MapReduceConfig, optional): Map/reduce LLM settings. Defaults to the module default.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,
        data_config: DataConfig,
        context_config: ContextConfig = _default_context_config,
        mapreduce_config: MapReduceConfig = _default_mapreduce_config,
    ):
        super().__init__(
            args_type=GlobalSearchToolArgs,
            return_type=GlobalSearchToolReturn,
            name="global_search_tool",
            description="Perform a global search with given parameters using graphrag.",
        )
        # Keep a handle on the caller-supplied chat model.
        self._model = model

        def _read_table(table_name: str) -> pd.DataFrame:
            # Every index table is stored as "<input_dir>/<table>.parquet".
            return pd.read_parquet(f"{data_config.input_dir}/{table_name}.parquet")  # type: ignore

        # Read the three tables in the order: communities, entities, reports.
        communities_frame: pd.DataFrame = _read_table(data_config.community_table)
        entities_frame: pd.DataFrame = _read_table(data_config.entity_table)
        reports_frame: pd.DataFrame = _read_table(data_config.community_report_table)

        # Convert the raw frames into GraphRAG objects via the indexer adapters.
        level = data_config.community_level
        communities = read_indexer_communities(communities_frame, reports_frame)
        reports = read_indexer_reports(reports_frame, communities_frame, level)
        entities = read_indexer_entities(entities_frame, communities_frame, level)

        community_context = GlobalCommunityContext(
            community_reports=reports,
            communities=communities,
            entities=entities,
            token_encoder=token_encoder,
        )

        # Parameters handed to the context builder by the search engine.
        builder_settings = {
            "use_community_summary": context_config.use_community_summary,
            "shuffle_data": context_config.shuffle_data,
            "include_community_rank": context_config.include_community_rank,
            "min_community_rank": context_config.min_community_rank,
            "community_rank_name": context_config.community_rank_name,
            "include_community_weight": context_config.include_community_weight,
            "community_weight_name": context_config.community_weight_name,
            "normalize_community_weight": context_config.normalize_community_weight,
            "max_tokens": context_config.max_data_tokens,
            "context_name": "Reports",
        }

        # The map step always requests a JSON object response.
        map_settings = {
            "max_tokens": mapreduce_config.map_max_tokens,
            "temperature": mapreduce_config.map_temperature,
            "response_format": {"type": "json_object"},
        }
        reduce_settings = {
            "max_tokens": mapreduce_config.reduce_max_tokens,
            "temperature": mapreduce_config.reduce_temperature,
        }

        self._search_engine = GlobalSearch(
            model=self._model,
            context_builder=community_context,
            token_encoder=token_encoder,
            max_data_tokens=context_config.max_data_tokens,
            map_llm_params=map_settings,
            reduce_llm_params=reduce_settings,
            allow_general_knowledge=mapreduce_config.allow_general_knowledge,
            json_mode=mapreduce_config.json_mode,
            context_builder_params=builder_settings,
            concurrent_coroutines=32,
            response_type=mapreduce_config.response_type,
        )

    async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
        """Run the global search for ``args.query`` and return its textual response.

        ``cancellation_token`` is accepted for interface compatibility but is not
        consulted by this implementation.
        """
        outcome = await self._search_engine.search(args.query)
        response_text = outcome.response
        assert isinstance(response_text, str), "Expected response to be a string"
        return GlobalSearchToolReturn(answer=response_text)

    @classmethod
    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
        """Build a GlobalSearchTool from a GraphRAG settings file.

        Args:
            root_dir: GraphRAG root directory.
            config_filepath: Optional path to the GraphRAG settings file.

        Returns:
            A ready-to-use GlobalSearchTool that uses the module default context
            and map/reduce configs.

        Raises:
            ValueError: If the settings contain no default chat model entry.
        """
        # String paths are promoted to Path objects before loading the config.
        root_path = Path(root_dir) if isinstance(root_dir, str) else root_dir
        settings_path = Path(config_filepath) if isinstance(config_filepath, str) else config_filepath
        config = load_config(root_dir=root_path, config_filepath=settings_path)

        # Look up the default chat model in the settings' models section.
        chat_settings = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        if chat_settings is None:
            raise ValueError("default_chat_model not found in config.models")

        # Pick the tokenizer for the configured model; tiktoken raises KeyError
        # for models it does not know, in which case cl100k_base is used.
        try:
            encoder = tiktoken.encoding_for_model(chat_settings.model)
        except KeyError:
            encoder = tiktoken.get_encoding("cl100k_base")

        # Obtain the chat model through GraphRAG's ModelManager.
        chat_model = ModelManager().get_or_create_chat_model(
            name="global_search_model",
            model_type=chat_settings.type,
            config=chat_settings,
        )

        # The index tables are read from the configured output directory.
        tables = DataConfig(input_dir=str(config.output.base_dir))

        return cls(
            token_encoder=encoder,
            model=chat_model,
            data_config=tables,
            context_config=_default_context_config,
            mapreduce_config=_default_mapreduce_config,
        )
|
||||
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# mypy: disable-error-code="no-any-unimported,misc"
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
from ._config import LocalContextConfig, SearchConfig
|
||||
from ._config import LocalDataConfig as DataConfig
|
||||
|
||||
# Module-level default instances of the local-search context and search configs.
# NOTE(review): presumably used as default arguments of LocalSearchTool, mirroring
# the global search module — confirm in the class below.
_default_context_config = LocalContextConfig()
_default_search_config = SearchConfig()
|
||||
|
||||
|
||||
class LocalSearchToolArgs(BaseModel):
    # Argument schema of LocalSearchTool (its ``BaseTool`` args type).
    query: str = Field(..., description="The user query to perform local search on.")
|
||||
|
||||
|
||||
class LocalSearchToolReturn(BaseModel):
    # Return schema of LocalSearchTool (its ``BaseTool`` return type).
    answer: str = Field(..., description="The answer to the user query.")
|
||||
|
||||
|
||||
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
    """Enables running GraphRAG local search queries as an AutoGen tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines local document context with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`agentdhal-ext` package.
        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        Before using this tool, you must complete the GraphRAG setup and indexing process:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
        for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import LocalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up local search tool
            local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the local search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[local_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For specific, detailed information about particular entities or relationships, call the 'local_search_tool' function."
                ),
            )

            # Run a sample query
            query = "What does the station-master say about Dr. Becher?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())


    Args:
        token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
        model: The chat model to use for search (GraphRAG ChatModel)
        embedder: The text embedding model to use (GraphRAG EmbeddingModel)
        data_config (DataConfig): Configuration for data source locations and settings
        context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
        search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,  # ChatModel from GraphRAG
        embedder: EmbeddingModel,  # EmbeddingModel from GraphRAG
        data_config: DataConfig,
        context_config: LocalContextConfig = _default_context_config,
        search_config: SearchConfig = _default_search_config,
    ):
        super().__init__(
            args_type=LocalSearchToolArgs,
            return_type=LocalSearchToolReturn,
            name="local_search_tool",
            description="Perform a local search with given parameters using graphrag.",
        )
        # Use the provided models
        self._model = model
        self._embedder = embedder

        def read_table(table_name: str) -> pd.DataFrame:
            # All index tables live side by side as "<input_dir>/<table>.parquet".
            return pd.read_parquet(f"{data_config.input_dir}/{table_name}.parquet")  # type: ignore

        # Load parquet files (same order as before so a missing file fails on the same table)
        entity_df = read_table(data_config.entity_table)
        relationship_df = read_table(data_config.relationship_table)
        text_unit_df = read_table(data_config.text_unit_table)
        community_df = read_table(data_config.community_table)

        # Read data using indexer adapters
        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
        relationships = read_indexer_relationships(relationship_df)
        text_units = read_indexer_text_units(text_unit_df)

        # Set up vector store for entity embeddings
        description_embedding_store = LanceDBVectorStore(
            collection_name="default-entity-description",
        )
        description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb")

        # Set up context builder
        context_builder = LocalSearchMixedContext(
            entities=entities,
            entity_text_embeddings=description_embedding_store,
            text_embedder=self._embedder,
            text_units=text_units,
            relationships=relationships,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "text_unit_prop": context_config.text_unit_prop,
            "community_prop": context_config.community_prop,
            "include_entity_rank": context_config.include_entity_rank,
            "rank_description": context_config.rank_description,
            "include_relationship_weight": context_config.include_relationship_weight,
            "relationship_ranking_attribute": context_config.relationship_ranking_attribute,
            "max_tokens": context_config.max_data_tokens,
        }

        llm_params = {
            "max_tokens": search_config.max_tokens,
            "temperature": search_config.temperature,
        }

        self._search_engine = LocalSearch(
            model=self._model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            response_type=search_config.response_type,
            context_builder_params=context_builder_params,
            model_params=llm_params,
        )

    async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
        """Run a local search for ``args.query`` and wrap the textual response.

        Args:
            args: The validated tool arguments carrying the query string.
            cancellation_token: Accepted for interface compatibility; it is not
                consulted by this implementation.

        Returns:
            A LocalSearchToolReturn whose ``answer`` is the search engine's response.

        Raises:
            AssertionError: If the search engine's response is not a string.
        """
        search_result = await self._search_engine.search(args.query)  # type: ignore[reportUnknownMemberType]
        response = search_result.response
        # Explicit raise instead of an `assert` statement so the check still runs
        # under `python -O`; exception type and message are unchanged.
        if not isinstance(response, str):
            raise AssertionError("Expected response to be a string")
        return LocalSearchToolReturn(answer=response)

    @classmethod
    def from_settings(cls, root_dir: Path, config_filepath: Path | None = None) -> "LocalSearchTool":
        """Create a LocalSearchTool instance from GraphRAG settings file.

        Args:
            root_dir: Path to the GraphRAG root directory
            config_filepath: Path to the GraphRAG settings file (optional)

        Returns:
            An initialized LocalSearchTool instance

        Raises:
            ValueError: If the default chat or embedding model is missing from ``config.models``.
        """
        # Load GraphRAG config
        config = load_config(root_dir=root_dir, config_filepath=config_filepath)

        # Get the language model configurations from the models section
        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)

        if chat_model_config is None:
            raise ValueError("default_chat_model not found in config.models")
        if embedding_model_config is None:
            raise ValueError("default_embedding_model not found in config.models")

        # Initialize token encoder based on the model being used
        try:
            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
        except KeyError:
            # Fallback to cl100k_base if model is not recognized by tiktoken
            token_encoder = tiktoken.get_encoding("cl100k_base")

        # Create the models using ModelManager
        model = ModelManager().get_or_create_chat_model(
            name="local_search_model",
            model_type=chat_model_config.type,
            config=chat_model_config,
        )

        embedder = ModelManager().get_or_create_embedding_model(
            name="local_search_embedder",
            model_type=embedding_model_config.type,
            config=embedding_model_config,
        )

        # Create data config from storage paths
        data_config = DataConfig(
            input_dir=str(config.output.base_dir),
        )

        return cls(
            token_encoder=token_encoder,
            model=model,
            embedder=embedder,
            data_config=data_config,
            context_config=_default_context_config,
            search_config=_default_search_config,
        )
|
||||
3
agent_dhal/agentdhal_extensions/tools/http/__init__.py
Normal file
3
agent_dhal/agentdhal_extensions/tools/http/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._http_tool import HttpTool
|
||||
|
||||
__all__ = ["HttpTool"]
|
||||
244
agent_dhal/agentdhal_extensions/tools/http/_http_tool.py
Normal file
244
agent_dhal/agentdhal_extensions/tools/http/_http_tool.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import httpx
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from json_schema_to_pydantic import create_model
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
# Default HTTP request timeout, in seconds; used as the default for
# HttpToolConfig.timeout and HttpTool's `timeout` parameter.
DEFAULT_TIMEOUT_CONFIG = 5.0
|
||||
|
||||
|
||||
class HttpToolConfig(BaseModel):
    name: str
    """
    The name of the tool.
    """
    # Defaults to the same value HttpTool.__init__ uses, so a serialized config
    # may omit the field instead of failing validation.
    description: Optional[str] = "HTTP tool"
    """
    A description of the tool.
    """
    scheme: Literal["http", "https"] = "http"
    """
    The scheme to use for the request.
    """
    host: str
    """
    The host to send the request to.
    """
    port: int
    """
    The port to send the request to.
    """
    path: str = Field(default="/")
    """
    The path to send the request to. defaults to "/"
    The path can accept parameters, e.g. "/{param1}/{param2}".
    These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request.
    """
    method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
    """
    The HTTP method to use, will default to POST if not provided.
    """
    # `Optional[...]` alone does not make a pydantic field optional; the explicit
    # default does. None means "send no extra headers".
    headers: Optional[dict[str, Any]] = None
    """
    A dictionary of headers to send with the request.
    """
    json_schema: dict[str, Any]
    """
    A JSON Schema object defining the expected parameters for the tool.
    Path parameters MUST also be included in the json_schema. They MUST also be typed as strings.
    """
    return_type: Optional[Literal["text", "json"]] = "text"
    """
    The type of response to return from the tool.
    """
    timeout: float = DEFAULT_TIMEOUT_CONFIG
    """
    The timeout for the tool request in seconds.
    """
|
||||
|
||||
|
||||
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
|
||||
"""A wrapper for using an HTTP server as a tool.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool.
|
||||
description (str, optional): A description of the tool.
|
||||
scheme (str): The scheme to use for the request. Must be either "http" or "https".
|
||||
host (str): The host to send the request to.
|
||||
port (int): The port to send the request to.
|
||||
path (str, optional): The path to send the request to. Defaults to "/".
|
||||
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
|
||||
method (str, optional): The HTTP method to use, will default to POST if not provided.
|
||||
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
|
||||
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
|
||||
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
|
||||
Path parameters must also be included in the schema and must be strings.
|
||||
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
|
||||
Defaults to "text".
|
||||
timeout (float, optional): The timeout for HTTP requests in seconds.
|
||||
Defaults to 5.0.
|
||||
|
||||
.. note::
|
||||
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
|
||||
|
||||
To install:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-agentchat" "agentdhal-ext[http-tool]"
|
||||
|
||||
Example:
|
||||
Simple use case::
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.http import HttpTool
|
||||
|
||||
# Define a JSON schema for a base64 decode tool
|
||||
base64_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"type": "string", "description": "The base64 value to decode"},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
|
||||
# Create an HTTP tool for the httpbin API
|
||||
base64_tool = HttpTool(
|
||||
name="base64_decode",
|
||||
description="base64 decode a value",
|
||||
scheme="https",
|
||||
host="httpbin.org",
|
||||
port=443,
|
||||
path="/base64/{value}",
|
||||
method="GET",
|
||||
json_schema=base64_schema,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
# Create an assistant with the base64 tool
|
||||
model = OpenAIChatCompletionClient(model="gpt-4")
|
||||
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
|
||||
|
||||
# The assistant can now use the base64 tool to decode the string
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
|
||||
CancellationToken(),
|
||||
)
|
||||
print(response.chat_message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "tool"
|
||||
component_provider_override = "agentdhal_extensions.tools.http.HttpTool"
|
||||
component_config_schema = HttpToolConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
json_schema: dict[str, Any],
|
||||
headers: Optional[dict[str, Any]] = None,
|
||||
description: str = "HTTP tool",
|
||||
path: str = "/",
|
||||
scheme: Literal["http", "https"] = "http",
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
|
||||
return_type: Literal["text", "json"] = "text",
|
||||
timeout: float = DEFAULT_TIMEOUT_CONFIG,
|
||||
) -> None:
|
||||
self.server_params = HttpToolConfig(
|
||||
name=name,
|
||||
description=description,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
scheme=scheme,
|
||||
method=method,
|
||||
headers=headers,
|
||||
json_schema=json_schema,
|
||||
return_type=return_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Use regex to find all path parameters, we will need those later to template the path
|
||||
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
|
||||
self._path_params = path_params
|
||||
|
||||
# Create the input model from the modified schema
|
||||
input_model = create_model(json_schema)
|
||||
|
||||
# Use Any as return type since HTTP responses can vary
|
||||
base_return_type: Type[Any] = object
|
||||
|
||||
super().__init__(input_model, base_return_type, name, description)
|
||||
|
||||
def _to_config(self) -> HttpToolConfig:
|
||||
copied_config = self.server_params.model_copy()
|
||||
return copied_config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HttpToolConfig) -> Self:
|
||||
copied_config = config.model_copy().model_dump()
|
||||
return cls(**copied_config)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
"""Execute the HTTP tool with the given arguments.
|
||||
|
||||
Args:
|
||||
args: The validated input arguments
|
||||
cancellation_token: Token for cancelling the operation
|
||||
|
||||
Returns:
|
||||
The response body from the HTTP call in JSON format
|
||||
|
||||
Raises:
|
||||
Exception: If tool execution fails
|
||||
"""
|
||||
|
||||
model_dump = args.model_dump()
|
||||
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
|
||||
# Remove path params from the model dump
|
||||
for k in self._path_params:
|
||||
model_dump.pop(k)
|
||||
|
||||
path = self.server_params.path.format(**path_params)
|
||||
|
||||
url = httpx.URL(
|
||||
scheme=self.server_params.scheme,
|
||||
host=self.server_params.host,
|
||||
port=self.server_params.port,
|
||||
path=path,
|
||||
)
|
||||
timeout_config = httpx.Timeout(timeout=self.server_params.timeout)
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
match self.server_params.method:
|
||||
case "GET":
|
||||
response = await client.get(url, headers=self.server_params.headers, params=model_dump)
|
||||
case "PUT":
|
||||
response = await client.put(url, headers=self.server_params.headers, json=model_dump)
|
||||
case "DELETE":
|
||||
response = await client.delete(url, headers=self.server_params.headers, params=model_dump)
|
||||
case "PATCH":
|
||||
response = await client.patch(url, headers=self.server_params.headers, json=model_dump)
|
||||
case _: # Default case POST
|
||||
response = await client.post(url, headers=self.server_params.headers, json=model_dump)
|
||||
|
||||
match self.server_params.return_type:
|
||||
case "text":
|
||||
return response.text
|
||||
case "json":
|
||||
return response.json()
|
||||
case _:
|
||||
raise ValueError(f"Invalid return type: {self.server_params.return_type}")
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._langchain_adapter import LangChainToolAdapter
|
||||
|
||||
__all__ = ["LangChainToolAdapter"]
|
||||
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, cast
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool as LangChainTool
|
||||
|
||||
|
||||
class LangChainToolAdapter(BaseTool[BaseModel, Any]):
    """Allows you to wrap a LangChain tool and make it available to AutoGen.

    .. note::

        This class requires the :code:`langchain` extra for the :code:`agentdhal-ext` package.

        .. code-block:: bash

            pip install -U "agentdhal-ext[langchain]"


    Args:
        langchain_tool (LangChainTool): A LangChain tool to wrap

    Examples:

        Use the `PythonAstREPLTool` from the `langchain_experimental` package to
        create a tool that allows you to interact with a Pandas DataFrame.

        .. code-block:: python

            import asyncio
            import pandas as pd
            from langchain_experimental.tools.python.tool import PythonAstREPLTool
            from agentdhal_extensions.tools.langchain import LangChainToolAdapter
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.messages import TextMessage
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.ui import Console
            from agentdhal_core import CancellationToken


            async def main() -> None:
                df = pd.read_csv("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv")  # type: ignore
                tool = LangChainToolAdapter(PythonAstREPLTool(locals={"df": df}))
                model_client = OpenAIChatCompletionClient(model="gpt-4o")
                agent = AssistantAgent(
                    "assistant",
                    tools=[tool],
                    model_client=model_client,
                    system_message="Use the `df` variable to access the dataset.",
                )
                await Console(
                    agent.on_messages_stream(
                        [TextMessage(content="What's the average age of the passengers?", source="user")], CancellationToken()
                    )
                )


            asyncio.run(main())

        This example demonstrates how to use the `SQLDatabaseToolkit` from the `langchain_community`
        package to interact with an SQLite database.
        It uses the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` to iterate the single agent over multiple steps.
        If you want to run one step at a time, you can just call the `run_stream` method of the
        :class:`~agentdhal_agentchat.agents.AssistantAgent` class directly.

        .. code-block:: python

            import asyncio
            import sqlite3

            import requests
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_agentchat.ui import Console
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.langchain import LangChainToolAdapter
            from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
            from langchain_community.utilities.sql_database import SQLDatabase
            from langchain_openai import ChatOpenAI
            from sqlalchemy import Engine, create_engine
            from sqlalchemy.pool import StaticPool


            def get_engine_for_chinook_db() -> Engine:
                url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
                response = requests.get(url)
                sql_script = response.text
                connection = sqlite3.connect(":memory:", check_same_thread=False)
                connection.executescript(sql_script)
                return create_engine(
                    "sqlite://",
                    creator=lambda: connection,
                    poolclass=StaticPool,
                    connect_args={"check_same_thread": False},
                )


            async def main() -> None:
                # Create the engine and database wrapper.
                engine = get_engine_for_chinook_db()
                db = SQLDatabase(engine)

                # Create the toolkit.
                llm = ChatOpenAI(temperature=0)
                toolkit = SQLDatabaseToolkit(db=db, llm=llm)

                # Create the LangChain tool adapter for every tool in the toolkit.
                tools = [LangChainToolAdapter(tool) for tool in toolkit.get_tools()]

                # Create the chat completion client.
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                # Create the assistant agent.
                agent = AssistantAgent(
                    "assistant",
                    model_client=model_client,
                    tools=tools,  # type: ignore
                    model_client_stream=True,
                    system_message="Respond with 'TERMINATE' if the task is completed.",
                )

                # Create termination condition.
                termination = TextMentionTermination("TERMINATE")

                # Create a round-robin group chat to iterate the single agent over multiple steps.
                chat = RoundRobinGroupChat([agent], termination_condition=termination)

                # Run the chat.
                await Console(chat.run_stream(task="Show some tables in the database"))


            if __name__ == "__main__":
                asyncio.run(main())

    """

    def __init__(self, langchain_tool: LangChainTool):
        self._langchain_tool: LangChainTool = langchain_tool

        # Extract name and description
        name = self._langchain_tool.name
        description = self._langchain_tool.description or ""

        # Determine the callable method: prefer an explicit `func`, fall back to `_run`
        self._callable: Callable[..., Any]
        if hasattr(self._langchain_tool, "func") and callable(self._langchain_tool.func):  # type: ignore
            assert self._langchain_tool.func is not None  # type: ignore
            self._callable = self._langchain_tool.func  # type: ignore
        elif hasattr(self._langchain_tool, "_run") and callable(self._langchain_tool._run):  # type: ignore
            self._callable = self._langchain_tool._run  # type: ignore
        else:
            raise AttributeError(
                f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method."
            )

        # Determine args_type
        if self._langchain_tool.args_schema:  # pyright: ignore
            args_type = self._langchain_tool.args_schema  # pyright: ignore
        else:
            # Infer args_type from the callable's signature
            sig = inspect.signature(cast(Callable[..., Any], self._callable))  # type: ignore
            fields: Dict[str, Any] = {}
            for param_name, param in sig.parameters.items():
                if param_name == "self" or param.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    continue
                # An unannotated parameter has no usable type; accept anything.
                annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
                # Keep the callable's own default; only parameters without one are required.
                default = ... if param.default is inspect.Parameter.empty else param.default
                fields[param_name] = (annotation, Field(default))
            args_type = create_model(f"{name}Args", **fields)  # type: ignore
            # Note: type ignore is used due to a LangChain typing limitation

        # Ensure args_type is a subclass of BaseModel. The isinstance guard keeps a
        # non-class schema from making issubclass raise a TypeError instead.
        if not (isinstance(args_type, type) and issubclass(args_type, BaseModel)):
            raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}")

        # Assume return_type as Any if not specified
        return_type: Type[Any] = object

        super().__init__(args_type, return_type, name, description)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """Invoke the wrapped LangChain callable with the validated arguments.

        Args:
            args: The validated arguments; dumped to keyword arguments for the callable.
            cancellation_token: Accepted for interface compatibility; it is not
                consulted by this implementation.

        Returns:
            Whatever the wrapped callable returns.
        """
        # Prepare arguments
        kwargs = args.model_dump()

        # Determine if the callable is asynchronous
        if inspect.iscoroutinefunction(self._callable):
            return await self._callable(**kwargs)
        else:
            # Run in a thread to avoid blocking the event loop
            return await asyncio.to_thread(self._call_sync, kwargs)

    def _call_sync(self, kwargs: Dict[str, Any]) -> Any:
        """Call the wrapped synchronous callable with ``kwargs`` expanded as keyword arguments."""
        return self._callable(**kwargs)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user