first commit
This commit is contained in:
10
agent_dhal/agentdhal_extensions/agents/azure/__init__.py
Normal file
10
agent_dhal/agentdhal_extensions/agents/azure/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
try:
|
||||
from ._azure_ai_agent import AzureAIAgent
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Dependencies for AzureAIAgent not found. "
|
||||
'Please install autogen-ext with the "azure" extra: '
|
||||
'pip install "agentdhal-ext[azure]"'
|
||||
) from e
|
||||
|
||||
__all__ = ["AzureAIAgent"]
|
||||
1096
agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py
Normal file
1096
agent_dhal/agentdhal_extensions/agents/azure/_azure_ai_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
61
agent_dhal/agentdhal_extensions/agents/azure/_types.py
Normal file
61
agent_dhal/agentdhal_extensions/agents/azure/_types.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union
|
||||
|
||||
from agentdhal_core.tools import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from azure.ai.agents.models import (
|
||||
AzureAISearchToolDefinition,
|
||||
AzureFunctionToolDefinition,
|
||||
BingGroundingToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
FileSearchToolDefinition,
|
||||
MessageTextUrlCitationAnnotation,
|
||||
)
|
||||
|
||||
ListToolType = Iterable[
|
||||
Union[
|
||||
Literal[
|
||||
"file_search",
|
||||
"code_interpreter",
|
||||
"bing_grounding",
|
||||
"azure_ai_search",
|
||||
"azure_function",
|
||||
],
|
||||
BingGroundingToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
AzureAISearchToolDefinition,
|
||||
FileSearchToolDefinition,
|
||||
AzureFunctionToolDefinition,
|
||||
Tool,
|
||||
Callable[..., Any],
|
||||
Callable[..., Awaitable[Any]],
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class AzureAIAgentState(BaseModel):
|
||||
"""
|
||||
Represents the state of an AzureAIAgent that can be saved and loaded.
|
||||
|
||||
This state model keeps track of persistent information about an agent session
|
||||
including agent and thread identifiers, message history, and associated resources.
|
||||
|
||||
Attributes:
|
||||
type (str): The type identifier for the state object, always "AzureAIAgentState"
|
||||
agent_id (Optional[str]): The ID of the Azure AI agent
|
||||
thread_id (Optional[str]): The ID of the conversation thread
|
||||
initial_message_ids (List[str]): List of message IDs from the initial state
|
||||
vector_store_id (Optional[str]): The ID of the associated vector store for file search
|
||||
uploaded_file_ids (List[str]): List of IDs for files uploaded to the agent
|
||||
"""
|
||||
|
||||
type: str = Field(default="AzureAIAgentState")
|
||||
agent_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
initial_message_ids: List[str] = Field(default_factory=list)
|
||||
vector_store_id: Optional[str] = None
|
||||
uploaded_file_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
def has_annotations(obj: Any) -> TypeGuard[list[MessageTextUrlCitationAnnotation]]:
|
||||
return obj is not None and isinstance(obj, list)
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._file_surfer import FileSurfer
|
||||
|
||||
__all__ = ["FileSurfer"]
|
||||
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from agentdhal_agentchat.utils import remove_images
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._markdown_file_browser import MarkdownFileBrowser
|
||||
|
||||
# from typing_extensions import Annotated
|
||||
from ._tool_definitions import (
|
||||
TOOL_FIND_NEXT,
|
||||
TOOL_FIND_ON_PAGE_CTRL_F,
|
||||
TOOL_OPEN_PATH,
|
||||
TOOL_PAGE_DOWN,
|
||||
TOOL_PAGE_UP,
|
||||
)
|
||||
|
||||
|
||||
class FileSurferConfig(BaseModel):
|
||||
"""Configuration for FileSurfer agent"""
|
||||
|
||||
name: str
|
||||
model_client: ComponentModel
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
|
||||
"""An agent, used by MagenticOne, that acts as a local file previewer. FileSurfer can open and read a variety of common file types, and can navigate the local file hierarchy.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[file-surfer]"
|
||||
|
||||
Args:
|
||||
name (str): The agent's name
|
||||
model_client (ChatCompletionClient): The model to use (must be tool-use enabled)
|
||||
description (str): The agent's description used by the team. Defaults to DEFAULT_DESCRIPTION
|
||||
base_path (str): The base path to use for the file browser. Defaults to the current working directory.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = FileSurferConfig
|
||||
component_provider_override = "agentdhal_extensions.agents.file_surfer.FileSurfer"
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGES = [
|
||||
SystemMessage(
|
||||
content="""
|
||||
You are a helpful AI Assistant.
|
||||
When given a user query, use available functions to help the user with their request."""
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
base_path: str = os.getcwd(),
|
||||
) -> None:
|
||||
super().__init__(name, description)
|
||||
self._model_client = model_client
|
||||
self._chat_history: List[LLMMessage] = []
|
||||
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
for chat_message in messages:
|
||||
self._chat_history.append(chat_message.to_model_message())
|
||||
try:
|
||||
_, content = await self._generate_reply(cancellation_token=cancellation_token)
|
||||
self._chat_history.append(AssistantMessage(content=content, source=self.name))
|
||||
return Response(chat_message=TextMessage(content=content, source=self.name))
|
||||
|
||||
except BaseException:
|
||||
content = f"File surfing error:\n\n{traceback.format_exc()}"
|
||||
self._chat_history.append(AssistantMessage(content=content, source=self.name))
|
||||
return Response(chat_message=TextMessage(content=content, source=self.name))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
self._chat_history.clear()
|
||||
|
||||
def _get_browser_state(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Get the current state of the browser, including the header and content.
|
||||
"""
|
||||
header = f"Path: {self._browser.path}\n"
|
||||
|
||||
if self._browser.page_title is not None:
|
||||
header += f"Title: {self._browser.page_title}\n"
|
||||
|
||||
current_page = self._browser.viewport_current_page
|
||||
total_pages = len(self._browser.viewport_pages)
|
||||
header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"
|
||||
|
||||
return (header, self._browser.viewport)
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:
|
||||
history = self._chat_history[0:-1]
|
||||
last_message = self._chat_history[-1]
|
||||
assert isinstance(last_message, UserMessage)
|
||||
|
||||
task_content = last_message.content # the last message from the sender is the task
|
||||
|
||||
assert self._browser is not None
|
||||
|
||||
context_message = UserMessage(
|
||||
source="user",
|
||||
content=f"Your file viewer is currently open to the file or directory '{self._browser.page_title}' with path '{self._browser.path}'.",
|
||||
)
|
||||
|
||||
task_message = UserMessage(
|
||||
source="user",
|
||||
content=task_content,
|
||||
)
|
||||
|
||||
create_result = await self._model_client.create(
|
||||
messages=self._get_compatible_context(history + [context_message, task_message]),
|
||||
tools=[
|
||||
TOOL_OPEN_PATH,
|
||||
TOOL_PAGE_DOWN,
|
||||
TOOL_PAGE_UP,
|
||||
TOOL_FIND_NEXT,
|
||||
TOOL_FIND_ON_PAGE_CTRL_F,
|
||||
],
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
response = create_result.content
|
||||
|
||||
if isinstance(response, str):
|
||||
# Answer directly.
|
||||
return False, response
|
||||
|
||||
elif isinstance(response, list) and all(isinstance(item, FunctionCall) for item in response):
|
||||
function_calls = response
|
||||
for function_call in function_calls:
|
||||
tool_name = function_call.name
|
||||
|
||||
try:
|
||||
arguments = json.loads(function_call.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
error_str = f"File surfer encountered an error decoding JSON arguments: {e}"
|
||||
return False, error_str
|
||||
|
||||
if tool_name == "open_path":
|
||||
path = arguments["path"]
|
||||
self._browser.open_path(path)
|
||||
elif tool_name == "page_up":
|
||||
self._browser.page_up()
|
||||
elif tool_name == "page_down":
|
||||
self._browser.page_down()
|
||||
elif tool_name == "find_on_page_ctrl_f":
|
||||
search_string = arguments["search_string"]
|
||||
self._browser.find_on_page(search_string)
|
||||
elif tool_name == "find_next":
|
||||
self._browser.find_next()
|
||||
header, content = self._get_browser_state()
|
||||
final_response = header.strip() + "\n=======================\n" + content
|
||||
return False, final_response
|
||||
|
||||
final_response = "TERMINATE"
|
||||
return False, final_response
|
||||
|
||||
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
|
||||
if self._model_client.model_info["vision"]:
|
||||
return messages
|
||||
else:
|
||||
return remove_images(messages)
|
||||
|
||||
def _to_config(self) -> FileSurferConfig:
|
||||
return FileSurferConfig(
|
||||
name=self.name,
|
||||
model_client=self._model_client.dump_component(),
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: FileSurferConfig) -> Self:
|
||||
return cls(
|
||||
name=config.name,
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
description=config.description or cls.DEFAULT_DESCRIPTION,
|
||||
)
|
||||
@@ -0,0 +1,317 @@
|
||||
# ruff: noqa: E722
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
# TODO: Fix unfollowed import
|
||||
from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException # type: ignore
|
||||
|
||||
|
||||
class MarkdownFileBrowser:
|
||||
"""
|
||||
(In preview) An extremely simple Markdown-powered file browser.
|
||||
"""
|
||||
|
||||
# TODO: Fix unfollowed import
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
viewport_size: Union[int, None] = 1024 * 8,
|
||||
base_path: str | None = os.getcwd(),
|
||||
cwd: str | None = None,
|
||||
):
|
||||
"""
|
||||
Instantiate a new MarkdownFileBrowser.
|
||||
|
||||
Arguments:
|
||||
viewport_size: Approximately how many *characters* fit in the viewport. Viewport dimensions are adjusted dynamically to avoid cutting off words (default: 8192).
|
||||
base_path: The base path to use for the file browser. Files outside this path cannot be accessed. Defaults to the current working directory.
|
||||
cwd: The browser's current working directory. Defaults to the system's current working directory.
|
||||
"""
|
||||
self.viewport_size = viewport_size # Applies only to the standard uri types
|
||||
self.history: List[Tuple[str, float]] = list()
|
||||
self.page_title: Optional[str] = None
|
||||
self.viewport_current_page = 0
|
||||
self.viewport_pages: List[Tuple[int, int]] = list()
|
||||
self._markdown_converter = MarkItDown()
|
||||
self._base_path = None if base_path is None else os.path.realpath(base_path)
|
||||
self._page_content: str = ""
|
||||
self._find_on_page_query: Union[str, None] = None
|
||||
self._find_on_page_last_result: Union[int, None] = None # Location of the last result
|
||||
|
||||
# Set the working directory
|
||||
if cwd is None:
|
||||
if self._validate_path(os.getcwd()):
|
||||
# Use the current working directory if it's in the base path
|
||||
cwd = os.path.realpath(os.getcwd())
|
||||
elif self._base_path is not None:
|
||||
# Otherwise, use the base path
|
||||
cwd = os.path.realpath(self._base_path)
|
||||
else:
|
||||
raise ValueError("No valid working directory (cwd) provided.")
|
||||
elif not self._validate_path(cwd):
|
||||
# A cwd was provided, but it is not valid
|
||||
raise ValueError(f"Working directory (cwd) '{cwd}' is not valid. It must be within the base path.")
|
||||
|
||||
# Populate the history with the current working directory
|
||||
self.set_path(os.path.realpath(cwd))
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""Return the path of the current page."""
|
||||
assert len(self.history) > 0
|
||||
return self.history[-1][0]
|
||||
|
||||
def _validate_path(self, path: str) -> bool:
|
||||
"""Validates the path to ensure it is within the base path.
|
||||
|
||||
Arguments:
|
||||
path: The path to validate.
|
||||
Returns:
|
||||
True if the path is valid, False otherwise.
|
||||
"""
|
||||
if self._base_path is None:
|
||||
return True
|
||||
|
||||
# Normalize the paths
|
||||
path = os.path.realpath(path)
|
||||
base = os.path.realpath(self._base_path)
|
||||
|
||||
# Check if the path is within the base path
|
||||
if os.path.commonpath([path, base]) != base:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def set_path(self, path: str) -> None:
|
||||
"""Sets the path of the current page.
|
||||
This will result in the file being opened for reading.
|
||||
|
||||
Arguments:
|
||||
path: An absolute or relative path of the file or directory to open."
|
||||
"""
|
||||
|
||||
# Handle relative paths
|
||||
path = os.path.expanduser(path)
|
||||
if not os.path.isabs(path):
|
||||
if os.path.isfile(self.path):
|
||||
path = os.path.abspath(os.path.join(os.path.dirname(self.path), path))
|
||||
elif os.path.isdir(self.path):
|
||||
path = os.path.abspath(os.path.join(self.path, path))
|
||||
# If neither a file or a directory, take it verbatim
|
||||
|
||||
# Validating the path wrt. the base path is done in _open_path
|
||||
path = os.path.realpath(path)
|
||||
|
||||
self.history.append((path, time.time()))
|
||||
self._open_path(path)
|
||||
self.viewport_current_page = 0
|
||||
self.find_on_page_query = None
|
||||
self.find_on_page_viewport = None
|
||||
|
||||
@property
|
||||
def viewport(self) -> str:
|
||||
"""Return the content of the current viewport."""
|
||||
bounds = self.viewport_pages[self.viewport_current_page]
|
||||
return self.page_content[bounds[0] : bounds[1]]
|
||||
|
||||
@property
|
||||
def page_content(self) -> str:
|
||||
"""Return the full contents of the current page."""
|
||||
return self._page_content
|
||||
|
||||
def _set_page_content(self, content: str, split_pages: bool = True) -> None:
|
||||
"""Sets the text content of the current page."""
|
||||
self._page_content = content
|
||||
|
||||
if split_pages:
|
||||
self._split_pages()
|
||||
else:
|
||||
self.viewport_pages = [(0, len(self._page_content))]
|
||||
|
||||
if self.viewport_current_page >= len(self.viewport_pages):
|
||||
self.viewport_current_page = len(self.viewport_pages) - 1
|
||||
|
||||
def page_down(self) -> None:
|
||||
"""Move the viewport down one page, if possible."""
|
||||
self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
|
||||
|
||||
def page_up(self) -> None:
|
||||
"""Move the viewport up one page, if possible."""
|
||||
self.viewport_current_page = max(self.viewport_current_page - 1, 0)
|
||||
|
||||
def find_on_page(self, query: str) -> Union[str, None]:
|
||||
"""Searches for the query from the current viewport forward, looping back to the start if necessary."""
|
||||
|
||||
# Did we get here via a previous find_on_page search with the same query?
|
||||
# If so, map to find_next
|
||||
if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
|
||||
return self.find_next()
|
||||
|
||||
# Ok it's a new search start from the current viewport
|
||||
self._find_on_page_query = query
|
||||
viewport_match = self._find_next_viewport(query, self.viewport_current_page)
|
||||
if viewport_match is None:
|
||||
self._find_on_page_last_result = None
|
||||
return None
|
||||
else:
|
||||
self.viewport_current_page = viewport_match
|
||||
self._find_on_page_last_result = viewport_match
|
||||
return self.viewport
|
||||
|
||||
def find_next(self) -> Union[str, None]:
|
||||
"""Scroll to the next viewport that matches the query"""
|
||||
|
||||
if self._find_on_page_query is None:
|
||||
return None
|
||||
|
||||
starting_viewport = self._find_on_page_last_result
|
||||
if starting_viewport is None:
|
||||
starting_viewport = 0
|
||||
else:
|
||||
starting_viewport += 1
|
||||
if starting_viewport >= len(self.viewport_pages):
|
||||
starting_viewport = 0
|
||||
|
||||
viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
|
||||
if viewport_match is None:
|
||||
self._find_on_page_last_result = None
|
||||
return None
|
||||
else:
|
||||
self.viewport_current_page = viewport_match
|
||||
self._find_on_page_last_result = viewport_match
|
||||
return self.viewport
|
||||
|
||||
def _find_next_viewport(self, query: Optional[str], starting_viewport: int) -> Union[int, None]:
|
||||
"""Search for matches between the starting viewport looping when reaching the end."""
|
||||
|
||||
if query is None:
|
||||
return None
|
||||
|
||||
# Normalize the query, and convert to a regular expression
|
||||
nquery = re.sub(r"\*", "__STAR__", query)
|
||||
nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
|
||||
nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word
|
||||
nquery = nquery.replace("__STAR__", ".*").lower()
|
||||
|
||||
if nquery.strip() == "":
|
||||
return None
|
||||
|
||||
idxs: List[int] = list()
|
||||
idxs.extend(range(starting_viewport, len(self.viewport_pages)))
|
||||
idxs.extend(range(0, starting_viewport))
|
||||
|
||||
for i in idxs:
|
||||
bounds = self.viewport_pages[i]
|
||||
content = self.page_content[bounds[0] : bounds[1]]
|
||||
|
||||
# TODO: Remove markdown links and images
|
||||
ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
|
||||
if re.search(nquery, ncontent):
|
||||
return i
|
||||
|
||||
return None
|
||||
|
||||
def open_path(self, path: str) -> str:
|
||||
"""Open a file or directory in the file surfer."""
|
||||
self.set_path(path)
|
||||
return self.viewport
|
||||
|
||||
def _split_pages(self) -> None:
|
||||
"""Split the page contents into pages that are approximately the viewport size. Small deviations are permitted to ensure words are not broken."""
|
||||
# Handle empty pages
|
||||
if len(self._page_content) == 0:
|
||||
self.viewport_pages = [(0, 0)]
|
||||
return
|
||||
|
||||
# Break the viewport into pages
|
||||
self.viewport_pages = []
|
||||
start_idx = 0
|
||||
while start_idx < len(self._page_content):
|
||||
end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
|
||||
# Adjust to end on a space
|
||||
while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
|
||||
end_idx += 1
|
||||
self.viewport_pages.append((start_idx, end_idx))
|
||||
start_idx = end_idx
|
||||
|
||||
def _open_path(
|
||||
self,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""Open a file for reading, converting it to Markdown in the process.
|
||||
|
||||
Arguments:
|
||||
path: The path of the file or directory to open.
|
||||
"""
|
||||
|
||||
if not self._validate_path(path):
|
||||
# Not robust to TOCTOU issues.
|
||||
# Mitigate by running with limited permissions, or use a sandbox.
|
||||
self.page_title = "FileNotFoundError"
|
||||
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
|
||||
else:
|
||||
try:
|
||||
if os.path.isdir(path): # TODO: Fix markdown_converter types
|
||||
res = self._markdown_converter.convert_stream( # type: ignore
|
||||
io.BytesIO(self._fetch_local_dir(path).encode("utf-8")), file_extension=".txt"
|
||||
)
|
||||
assert self._validate_path(path)
|
||||
self.page_title = res.title
|
||||
self._set_page_content(res.text_content, split_pages=False)
|
||||
else:
|
||||
res = self._markdown_converter.convert_local(path)
|
||||
assert self._validate_path(path)
|
||||
self.page_title = res.title
|
||||
self._set_page_content(res.text_content)
|
||||
except UnsupportedFormatException:
|
||||
self.page_title = "UnsupportedFormatException"
|
||||
self._set_page_content(f"# UnsupportedFormatException\n\nCannot preview '{path}' as Markdown.")
|
||||
except FileConversionException:
|
||||
self.page_title = "FileConversionException."
|
||||
self._set_page_content(f"# FileConversionException\n\nError converting '{path}' to Markdown.")
|
||||
except FileNotFoundError:
|
||||
self.page_title = "FileNotFoundError"
|
||||
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
|
||||
|
||||
def _fetch_local_dir(self, local_path: str) -> str:
|
||||
"""Render a local directory listing in HTML to assist with local file browsing via the "file://" protocol.
|
||||
Through rendered in HTML, later parts of the pipeline will convert the listing to Markdown.
|
||||
|
||||
Arguments:
|
||||
local_path: A path to the local directory whose contents are to be listed.
|
||||
|
||||
Returns:
|
||||
A directory listing, rendered in HTML.
|
||||
"""
|
||||
listing = f"""
|
||||
# Index of {local_path}
|
||||
|
||||
| Name | Size | Date Modified |
|
||||
| ---- | ---- | ------------- |
|
||||
| .. (parent directory) | | |
|
||||
"""
|
||||
for entry in os.listdir(local_path):
|
||||
size = ""
|
||||
full_path = os.path.join(local_path, entry)
|
||||
|
||||
mtime = ""
|
||||
try:
|
||||
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(full_path)).strftime("%Y-%m-%d %H:%M")
|
||||
except Exception as e:
|
||||
# Handles PermissionError, etc.
|
||||
mtime = f"N/A: {type(e).__name__}"
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
entry = entry + os.path.sep
|
||||
else:
|
||||
try:
|
||||
size = str(os.path.getsize(full_path))
|
||||
except Exception as e:
|
||||
# Handles PermissionError, etc.
|
||||
size = f"N/A: {type(e).__name__}"
|
||||
|
||||
listing += f"| {entry} | {size} | {mtime} |\n"
|
||||
return listing
|
||||
@@ -0,0 +1,50 @@
|
||||
from agentdhal_core.tools import ParametersSchema, ToolSchema
|
||||
|
||||
TOOL_OPEN_PATH = ToolSchema(
|
||||
name="open_path",
|
||||
description="Open a local file or directory at a path in the text-based file browser and return current viewport content.",
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties={
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The relative or absolute path of a local file to visit.",
|
||||
},
|
||||
},
|
||||
required=["path"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
TOOL_PAGE_UP = ToolSchema(
|
||||
name="page_up",
|
||||
description="Scroll the viewport UP one page-length in the current file and return the new viewport content.",
|
||||
)
|
||||
|
||||
|
||||
TOOL_PAGE_DOWN = ToolSchema(
|
||||
name="page_down",
|
||||
description="Scroll the viewport DOWN one page-length in the current file and return the new viewport content.",
|
||||
)
|
||||
|
||||
|
||||
TOOL_FIND_ON_PAGE_CTRL_F = ToolSchema(
|
||||
name="find_on_page_ctrl_f",
|
||||
description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F.",
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties={
|
||||
"search_string": {
|
||||
"type": "string",
|
||||
"description": "The string to search for on the page. This search string supports wildcards like '*'",
|
||||
},
|
||||
},
|
||||
required=["search_string"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
TOOL_FIND_NEXT = ToolSchema(
|
||||
name="find_next",
|
||||
description="Scroll the viewport to next occurrence of the search string.",
|
||||
)
|
||||
@@ -0,0 +1,10 @@
|
||||
try:
|
||||
from ._magentic_one_coder_agent import MagenticOneCoderAgent
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Dependencies for MagenticOneCoderAgent not found. "
|
||||
'Please install autogen-ext with the "magentic-one" extra: '
|
||||
'pip install "agentdhal-ext[magentic-one]"'
|
||||
) from e
|
||||
|
||||
__all__ = ["MagenticOneCoderAgent"]
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import Any
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
)
|
||||
|
||||
MAGENTIC_ONE_CODER_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
|
||||
|
||||
MAGENTIC_ONE_CODER_SYSTEM_MESSAGE = """You are a helpful AI assistant.
|
||||
Solve tasks using your coding and language skills.
|
||||
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
|
||||
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
|
||||
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
|
||||
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
|
||||
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
|
||||
Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use the 'print' function for the output when relevant. Check the execution result returned by the user.
|
||||
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
|
||||
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible."""
|
||||
|
||||
|
||||
class MagenticOneCoderAgent(AssistantAgent):
|
||||
"""An agent, used by MagenticOne that provides coding assistance using an LLM model client.
|
||||
|
||||
The prompts and description are sealed, to replicate the original MagenticOne configuration. See AssistantAgent if you wish to modify these values.
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_extensions.agents.magentic_one.MagenticOneCoderAgent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
model_client,
|
||||
description=MAGENTIC_ONE_CODER_DESCRIPTION,
|
||||
system_message=MAGENTIC_ONE_CODER_SYSTEM_MESSAGE,
|
||||
)
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._openai_agent import OpenAIAgent
|
||||
from ._openai_assistant_agent import OpenAIAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
"OpenAIAgent",
|
||||
"OpenAIAssistantAgent",
|
||||
]
|
||||
682
agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py
Normal file
682
agent_dhal/agentdhal_extensions/agents/openai/_openai_agent.py
Normal file
@@ -0,0 +1,682 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from agentdhal_agentchat import EVENT_LOGGER_NAME
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.models import UserMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore
|
||||
|
||||
# Number of characters to display when previewing image content in logs and UI
|
||||
# Base64 encoded images can be very long, so we truncate for readability
|
||||
IMAGE_CONTENT_PREVIEW_LENGTH = 50
|
||||
|
||||
# NOTE: We use the new Responses API, so ChatCompletion imports are not needed.
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
# TypedDict classes for built-in tool configurations
|
||||
class FileSearchToolConfig(TypedDict):
|
||||
"""Configuration for file_search tool."""
|
||||
|
||||
type: Literal["file_search"]
|
||||
vector_store_ids: List[str] # required - The IDs of the vector stores to search
|
||||
max_num_results: NotRequired[int] # optional
|
||||
ranking_options: NotRequired[Dict[str, Any]] # optional
|
||||
filters: NotRequired[Dict[str, Any]] # optional
|
||||
|
||||
|
||||
class WebSearchToolConfig(TypedDict):
|
||||
"""Configuration for web_search_preview tool."""
|
||||
|
||||
type: Literal["web_search_preview"]
|
||||
search_context_size: NotRequired[str] # optional
|
||||
user_location: NotRequired[Union[str, Dict[str, Any]]] # optional - Can be string or structured location
|
||||
|
||||
|
||||
class ComputerUseToolConfig(TypedDict):
|
||||
"""Configuration for computer_use_preview tool."""
|
||||
|
||||
type: Literal["computer_use_preview"]
|
||||
display_height: int # required - Display height in pixels
|
||||
display_width: int # required - Display width in pixels
|
||||
environment: str # required - Environment type for computer use
|
||||
|
||||
|
||||
class MCPToolConfig(TypedDict):
|
||||
"""Configuration for mcp tool."""
|
||||
|
||||
type: Literal["mcp"]
|
||||
server_label: str # required - Label for the MCP server
|
||||
server_url: str # required - URL of the MCP server
|
||||
allowed_tools: NotRequired[List[str]] # optional - List of allowed tools
|
||||
headers: NotRequired[Dict[str, str]] # optional - HTTP headers for requests
|
||||
require_approval: NotRequired[bool] # optional - Whether to require user approval
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(TypedDict):
|
||||
"""Configuration for code_interpreter tool."""
|
||||
|
||||
type: Literal["code_interpreter"]
|
||||
container: str | Dict[str, Any] # required - Container configuration for code execution
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(TypedDict):
|
||||
"""Configuration for image_generation tool."""
|
||||
|
||||
type: Literal["image_generation"]
|
||||
background: NotRequired[str] # optional - Background color or image
|
||||
input_image_mask: NotRequired[str] # optional - Mask for input image editing
|
||||
|
||||
|
||||
class LocalShellToolConfig(TypedDict):
|
||||
"""Configuration for local_shell tool.
|
||||
|
||||
WARNING: This tool is only supported with the 'codex-mini-latest' model
|
||||
and is available exclusively through the Responses API.
|
||||
"""
|
||||
|
||||
type: Literal["local_shell"]
|
||||
# Note: local_shell currently has no additional parameters in the API
|
||||
|
||||
|
||||
# Union type for all built-in tool configurations
|
||||
BuiltinToolConfig = Union[
|
||||
FileSearchToolConfig,
|
||||
WebSearchToolConfig,
|
||||
ComputerUseToolConfig,
|
||||
MCPToolConfig,
|
||||
CodeInterpreterToolConfig,
|
||||
ImageGenerationToolConfig,
|
||||
LocalShellToolConfig,
|
||||
]
|
||||
|
||||
|
||||
# Define ImageMessage class early since it's used in _convert_message_to_openai_message
|
||||
class ImageMessage(BaseChatMessage):
|
||||
"""A message containing an image."""
|
||||
|
||||
content: str # URL or base64 string
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
return UserMessage(content=self.content, source=self.source)
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
return "[image]"
|
||||
|
||||
def to_text(self) -> str:
|
||||
# Truncate long image content (especially base64) for better readability
|
||||
# While still showing enough of the URL or content to be identifiable
|
||||
if len(self.content) > IMAGE_CONTENT_PREVIEW_LENGTH:
|
||||
return f"[Image: {self.content[:IMAGE_CONTENT_PREVIEW_LENGTH]}...]"
|
||||
return f"[Image: {self.content}]"
|
||||
|
||||
|
||||
class OpenAIMessageContent(TypedDict):
|
||||
type: str
|
||||
text: str
|
||||
|
||||
|
||||
class OpenAIImageUrlContent(TypedDict):
|
||||
url: str
|
||||
|
||||
|
||||
class OpenAIImageContent(TypedDict):
|
||||
type: str
|
||||
image_url: OpenAIImageUrlContent
|
||||
|
||||
|
||||
class OpenAIMessage(TypedDict):
|
||||
role: str
|
||||
content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]]
|
||||
|
||||
|
||||
def _convert_message_to_openai_message(
|
||||
message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage],
|
||||
) -> OpenAIMessage:
|
||||
"""Convert an AutoGen message to an OpenAI message format."""
|
||||
if isinstance(message, TextMessage):
|
||||
if message.source == "user":
|
||||
return {"role": "user", "content": str(message.content)}
|
||||
elif message.source == "system":
|
||||
return {"role": "system", "content": str(message.content)}
|
||||
elif message.source == "assistant":
|
||||
return {"role": "assistant", "content": str(message.content)}
|
||||
else:
|
||||
return {"role": "user", "content": str(message.content)}
|
||||
elif isinstance(message, MultiModalMessage):
|
||||
content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = []
|
||||
for part in message.content:
|
||||
if isinstance(part, TextMessage):
|
||||
content_parts.append({"type": "text", "text": str(part.content)})
|
||||
elif isinstance(part, ImageMessage):
|
||||
image_content = str(part.content)
|
||||
content_parts.append({"type": "image_url", "image_url": {"url": image_content}})
|
||||
return {"role": "user", "content": content_parts}
|
||||
else:
|
||||
return {"role": "user", "content": str(message.content)}
|
||||
|
||||
|
||||
class OpenAIAgentState(BaseModel):
|
||||
type: str = Field(default="OpenAIAgentState")
|
||||
response_id: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIAgentConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for OpenAI agent supporting OpenAI built-in tools only.
|
||||
|
||||
.. versionchanged:: v0.7.0
|
||||
Added support for built-in tools in JSON configuration via _to_config and _from_config methods.
|
||||
The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format).
|
||||
Custom tools are not supported.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
model: str
|
||||
instructions: str
|
||||
tools: List[Dict[str, Any] | str] | None = None
|
||||
temperature: Optional[float] = 1
|
||||
max_output_tokens: Optional[int] = None
|
||||
json_mode: bool = False
|
||||
store: bool = True
|
||||
truncation: str = "disabled"
|
||||
|
||||
|
||||
class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]):
|
||||
"""
|
||||
An agent implementation that uses the OpenAI Responses API to generate responses.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[openai]"
|
||||
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
|
||||
|
||||
This agent leverages the Responses API to generate responses with capabilities like:
|
||||
|
||||
* Multi-turn conversations
|
||||
* Built-in tool support (file_search, code_interpreter, web_search_preview, etc.)
|
||||
|
||||
Currently, custom tools are not supported.
|
||||
|
||||
.. versionchanged:: v0.7.0
|
||||
|
||||
Added support for built-in tool types like file_search, web_search_preview,
|
||||
code_interpreter, computer_use_preview, image_generation, and mcp.
|
||||
Added support for tool configurations with required and optional parameters.
|
||||
|
||||
Built-in tools are split into two categories:
|
||||
|
||||
**Tools that can use string format** (no required parameters):
|
||||
|
||||
- web_search_preview: Can be used as "web_search_preview" or with optional config
|
||||
(user_location, search_context_size)
|
||||
- image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask)
|
||||
- local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model)
|
||||
|
||||
**Tools that REQUIRE dict configuration** (have required parameters):
|
||||
|
||||
- file_search: MUST use dict with vector_store_ids (List[str])
|
||||
- computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str)
|
||||
- code_interpreter: MUST use dict with container (str)
|
||||
- mcp: MUST use dict with server_label (str), server_url (str)
|
||||
|
||||
Using required-parameter tools in string format will raise a ValueError with helpful error messages.
|
||||
The tools parameter type annotation only accepts string values for tools that don't require parameters.
|
||||
|
||||
Note:
|
||||
Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent.
|
||||
Only OpenAI built-in tools provided via the Responses API are supported.
|
||||
|
||||
|
||||
Args:
|
||||
name (str): Name of the agent
|
||||
description (str): Description of the agent's purpose
|
||||
client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance
|
||||
model (str): Model to use (e.g. "gpt-4.1")
|
||||
instructions (str): System instructions for the agent
|
||||
tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use.
|
||||
Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell".
|
||||
Dict values can provide configuration for built-in tools with parameters.
|
||||
Required parameters for built-in tools:
|
||||
- file_search: vector_store_ids (List[str])
|
||||
- computer_use_preview: display_height (int), display_width (int), environment (str)
|
||||
- code_interpreter: container (str)
|
||||
- mcp: server_label (str), server_url (str)
|
||||
Optional parameters for built-in tools:
|
||||
- file_search: max_num_results (int), ranking_options (dict), filters (dict)
|
||||
- web_search_preview: user_location (str or dict), search_context_size (int)
|
||||
- image_generation: background (str), input_image_mask (str)
|
||||
- mcp: allowed_tools (List[str]), headers (dict), require_approval (bool)
|
||||
Special tools with model restrictions:
|
||||
- local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support)
|
||||
Custom tools are not supported.
|
||||
temperature (Optional[float]): Temperature for response generation (default: 1)
|
||||
max_output_tokens (Optional[int]): Maximum output tokens
|
||||
json_mode (bool): Whether to use JSON mode (default: False)
|
||||
store (bool): Whether to store conversations (default: True)
|
||||
truncation (str): Truncation strategy (default: "disabled")
|
||||
|
||||
Example:
|
||||
|
||||
Basic usage with built-in tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.agents.openai import OpenAIAgent
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def example():
|
||||
client = AsyncOpenAI()
|
||||
agent = OpenAIAgent(
|
||||
name="SimpleAgent",
|
||||
description="A simple OpenAI agent using the Responses API",
|
||||
client=client,
|
||||
model="gpt-4.1",
|
||||
instructions="You are a helpful assistant.",
|
||||
tools=["web_search_preview"], # Only tools without required params
|
||||
)
|
||||
await Console(agent.run_stream(task="Search for recent AI developments"))
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Usage with configured built-in tools:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.agents.openai import OpenAIAgent
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
async def example_with_configs():
|
||||
client = AsyncOpenAI()
|
||||
# Configure tools with required and optional parameters
|
||||
tools = [
|
||||
# {
|
||||
# "type": "file_search",
|
||||
# "vector_store_ids": ["vs_abc123"], # required
|
||||
# "max_num_results": 10, # optional
|
||||
# },
|
||||
# {
|
||||
# "type": "computer_use_preview",
|
||||
# "display_height": 1024, # required
|
||||
# "display_width": 1280, # required
|
||||
# "environment": "linux", # required
|
||||
# },
|
||||
{
|
||||
"type": "code_interpreter",
|
||||
"container": {"type": "auto"}, # required
|
||||
},
|
||||
# {
|
||||
# "type": "mcp",
|
||||
# "server_label": "my-mcp-server", # required
|
||||
# "server_url": "http://localhost:3000", # required
|
||||
# },
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": { # optional - structured location
|
||||
"type": "approximate", # required: "approximate" or "exact"
|
||||
"country": "US", # optional
|
||||
"region": "CA", # optional
|
||||
"city": "San Francisco", # optional
|
||||
},
|
||||
"search_context_size": "low", # optional
|
||||
},
|
||||
# "image_generation", # Simple tools can still use string format
|
||||
]
|
||||
|
||||
agent = OpenAIAgent(
|
||||
name="ConfiguredAgent",
|
||||
description="An agent with configured tools",
|
||||
client=client,
|
||||
model="gpt-4.1",
|
||||
instructions="You are a helpful assistant with specialized tools.",
|
||||
tools=tools, # type: ignore
|
||||
)
|
||||
await Console(agent.run_stream(task="Search for recent AI developments"))
|
||||
|
||||
|
||||
asyncio.run(example_with_configs())
|
||||
|
||||
|
||||
Note:
|
||||
Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = OpenAIAgentConfig
|
||||
component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent"
|
||||
|
||||
def __init__(
|
||||
self: "OpenAIAgent",
|
||||
name: str,
|
||||
description: str,
|
||||
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
model: str,
|
||||
instructions: str,
|
||||
tools: Optional[
|
||||
Iterable[
|
||||
Union[
|
||||
Literal["web_search_preview", "image_generation", "local_shell"],
|
||||
BuiltinToolConfig,
|
||||
]
|
||||
]
|
||||
] = None,
|
||||
temperature: Optional[float] = 1,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
json_mode: bool = False,
|
||||
store: bool = True,
|
||||
truncation: str = "disabled",
|
||||
) -> None:
|
||||
super().__init__(name, description)
|
||||
self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client
|
||||
self._model: str = model
|
||||
self._instructions: str = instructions
|
||||
self._temperature: Optional[float] = temperature
|
||||
self._max_output_tokens: Optional[int] = max_output_tokens
|
||||
self._json_mode: bool = json_mode
|
||||
self._store: bool = store
|
||||
self._truncation: str = truncation
|
||||
self._last_response_id: Optional[str] = None
|
||||
self._message_history: List[Dict[str, Any]] = []
|
||||
self._tools: List[Dict[str, Any]] = []
|
||||
if tools is not None:
|
||||
for tool in tools:
|
||||
if isinstance(tool, str):
|
||||
# Handle built-in tool types
|
||||
self._add_builtin_tool(tool)
|
||||
elif isinstance(tool, dict) and "type" in tool:
|
||||
# Handle configured built-in tools
|
||||
self._tools.append(cast(dict[str, Any], tool))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
def _add_builtin_tool(self, tool_name: str) -> None:
|
||||
"""Add a built-in tool by name."""
|
||||
# Skip if an identical tool has already been registered (idempotent behaviour)
|
||||
if any(td.get("type") == tool_name for td in self._tools):
|
||||
return # Duplicate – ignore rather than raise to stay backward-compatible
|
||||
# Only allow string format for tools that don't require parameters
|
||||
if tool_name == "web_search_preview":
|
||||
self._tools.append({"type": "web_search_preview"})
|
||||
elif tool_name == "image_generation":
|
||||
self._tools.append({"type": "image_generation"})
|
||||
elif tool_name == "local_shell":
|
||||
# Special handling for local_shell - very limited model support
|
||||
if self._model != "codex-mini-latest":
|
||||
raise ValueError(
|
||||
f"Tool 'local_shell' is only supported with model 'codex-mini-latest', "
|
||||
f"but current model is '{self._model}'. "
|
||||
f"This tool is available exclusively through the Responses API and has severe limitations. "
|
||||
f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with "
|
||||
f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead."
|
||||
)
|
||||
self._tools.append({"type": "local_shell"})
|
||||
elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]:
|
||||
# These tools require specific parameters and must use dict configuration
|
||||
raise ValueError(
|
||||
f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. "
|
||||
f"Use dict configuration instead. Required parameters for {tool_name}: "
|
||||
f"{self._get_required_params_help(tool_name)}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported built-in tool type: {tool_name}")
|
||||
|
||||
def _get_required_params_help(self, tool_name: str) -> str:
|
||||
"""Get help text for required parameters of a tool."""
|
||||
help_text = {
|
||||
"file_search": "vector_store_ids (List[str])",
|
||||
"code_interpreter": "container (str | dict)",
|
||||
"computer_use_preview": "display_height (int), display_width (int), environment (str)",
|
||||
"mcp": "server_label (str), server_url (str)",
|
||||
}
|
||||
return help_text.get(tool_name, "unknown parameters")
|
||||
|
||||
def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]:
|
||||
"""Convert an OpenAIMessage to a Dict[str, Any]."""
|
||||
return dict(message)
|
||||
|
||||
@property
|
||||
def produced_message_types(
|
||||
self: "OpenAIAgent",
|
||||
) -> Sequence[
|
||||
Union[
|
||||
Type[TextMessage],
|
||||
Type[MultiModalMessage],
|
||||
Type[StopMessage],
|
||||
Type[ToolCallSummaryMessage],
|
||||
Type[HandoffMessage],
|
||||
]
|
||||
]:
|
||||
"""Return the types of messages that this agent can produce."""
|
||||
return [TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
|
||||
|
||||
# Custom tool execution is not supported by this agent.
|
||||
|
||||
def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
has_system_message = any(msg.get("role") == "system" for msg in messages)
|
||||
if self._instructions and not has_system_message:
|
||||
messages = [{"role": "system", "content": self._instructions}] + messages
|
||||
api_params: Dict[str, Any] = {
|
||||
"model": self._model,
|
||||
"input": messages, # Responses API expects 'input'
|
||||
}
|
||||
if self._temperature is not None:
|
||||
api_params["temperature"] = self._temperature
|
||||
if self._max_output_tokens is not None:
|
||||
api_params["max_output_tokens"] = self._max_output_tokens
|
||||
if self._tools:
|
||||
api_params["tools"] = self._tools
|
||||
if self._json_mode:
|
||||
api_params["text"] = {"type": "json_object"}
|
||||
api_params["store"] = self._store
|
||||
api_params["truncation"] = self._truncation
|
||||
if self._last_response_id:
|
||||
api_params["previous_response_id"] = self._last_response_id
|
||||
return api_params
|
||||
|
||||
async def on_messages(
|
||||
self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> Response:
|
||||
response = None
|
||||
inner_messages: List[
|
||||
Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
|
||||
] = []
|
||||
|
||||
async for msg in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(msg, Response):
|
||||
response = msg
|
||||
# ModelClientStreamingChunkEvent does not exist in this version, so skip this check
|
||||
else:
|
||||
inner_messages.append(msg)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("No response was generated")
|
||||
|
||||
if response.inner_messages is None:
|
||||
response.inner_messages = []
|
||||
|
||||
for msg in inner_messages:
|
||||
if msg not in response.inner_messages:
|
||||
response.inner_messages = list(response.inner_messages) + [msg]
|
||||
|
||||
return response
|
||||
|
||||
async def on_messages_stream(
|
||||
self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[
|
||||
Union[
|
||||
AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response
|
||||
],
|
||||
None,
|
||||
]:
|
||||
input_messages: List[Dict[str, Any]] = []
|
||||
|
||||
if self._message_history:
|
||||
input_messages.extend(self._message_history)
|
||||
|
||||
for message in messages:
|
||||
if isinstance(
|
||||
message, (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)
|
||||
):
|
||||
openai_message = _convert_message_to_openai_message(message)
|
||||
dict_message = self._convert_message_to_dict(openai_message)
|
||||
input_messages.append(dict_message)
|
||||
self._message_history.append(dict_message)
|
||||
else:
|
||||
msg_content = str(cast(Any, message).content) if hasattr(message, "content") else str(message)
|
||||
dict_message = {"role": "user", "content": msg_content}
|
||||
input_messages.append(dict_message)
|
||||
self._message_history.append(dict_message)
|
||||
|
||||
inner_messages: List[AgentEvent | ChatMessage] = []
|
||||
|
||||
api_params = self._build_api_parameters(input_messages)
|
||||
|
||||
try:
|
||||
client = cast(Any, self._client)
|
||||
response_obj = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(client.responses.create(**api_params))
|
||||
)
|
||||
content = getattr(response_obj, "output_text", None)
|
||||
response_id = getattr(response_obj, "id", None)
|
||||
self._last_response_id = response_id
|
||||
# Use a readable placeholder when the API returns no content to aid debugging
|
||||
content_str: str = str(content) if content is not None else "[no content returned]"
|
||||
self._message_history.append({"role": "assistant", "content": content_str})
|
||||
final_message = TextMessage(source=self.name, content=content_str)
|
||||
response = Response(chat_message=final_message, inner_messages=inner_messages)
|
||||
yield response
|
||||
except Exception as e:
|
||||
error_message = f"Error generating response: {str(e)}"
|
||||
event_logger.error(f"API error: {error_message}", exc_info=True)
|
||||
error_response = TextMessage(source=self.name, content=error_message)
|
||||
yield Response(chat_message=error_response, inner_messages=inner_messages)
|
||||
|
||||
async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None:
|
||||
self._last_response_id = None
|
||||
self._message_history = []
|
||||
|
||||
async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]:
|
||||
state = OpenAIAgentState(
|
||||
response_id=self._last_response_id,
|
||||
history=self._message_history,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None:
|
||||
agent_state = OpenAIAgentState.model_validate(state)
|
||||
self._last_response_id = agent_state.response_id
|
||||
self._message_history = agent_state.history
|
||||
|
||||
def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig:
|
||||
"""Convert the OpenAI agent to a declarative config.
|
||||
|
||||
Serializes built-in tools to their appropriate configuration formats for JSON serialization.
|
||||
|
||||
Returns:
|
||||
OpenAIAgentConfig: The configuration that can recreate this agent.
|
||||
"""
|
||||
return OpenAIAgentConfig(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
model=self._model,
|
||||
instructions=self._instructions,
|
||||
tools=list(self._tools),
|
||||
temperature=self._temperature,
|
||||
max_output_tokens=self._max_output_tokens,
|
||||
json_mode=self._json_mode,
|
||||
store=self._store,
|
||||
truncation=self._truncation,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent":
|
||||
"""Create an OpenAI agent from a declarative config.
|
||||
|
||||
Handles built-in tools (from string or dict configurations).
|
||||
|
||||
Args:
|
||||
config: The configuration to load the agent from.
|
||||
|
||||
Returns:
|
||||
OpenAIAgent: The reconstructed agent.
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI()
|
||||
|
||||
return cls(
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
client=client,
|
||||
model=config.model,
|
||||
instructions=config.instructions,
|
||||
tools=config.tools, # type: ignore
|
||||
temperature=config.temperature,
|
||||
max_output_tokens=config.max_output_tokens,
|
||||
json_mode=config.json_mode,
|
||||
store=config.store,
|
||||
truncation=config.truncation,
|
||||
)
|
||||
|
||||
# Add public API wrappers for configuration and tools
|
||||
def to_config(self) -> OpenAIAgentConfig:
|
||||
"""Public wrapper for the private _to_config method."""
|
||||
return self._to_config()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent":
|
||||
"""Public wrapper for the private _from_config classmethod."""
|
||||
return cls._from_config(config)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[Any]:
|
||||
"""Public access to the agent's tools."""
|
||||
return self._tools
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
"""Public access to the agent's model."""
|
||||
return self._model
|
||||
@@ -0,0 +1,715 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiofiles
|
||||
from agentdhal_agentchat import EVENT_LOGGER_NAME
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import (
|
||||
BaseAgentEvent,
|
||||
BaseChatMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
)
|
||||
from agentdhal_core import CancellationToken, FunctionCall, Image
|
||||
from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult
|
||||
from agentdhal_core.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types import FileObject
|
||||
from openai.types.beta import thread_update_params
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
|
||||
from openai.types.beta.threads import Message, MessageDeleted, Run
|
||||
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
|
||||
from openai.types.beta.threads.image_url_param import ImageURLParam
|
||||
from openai.types.beta.threads.message_content_part_param import (
|
||||
MessageContentPartParam,
|
||||
)
|
||||
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
from openai.types.vector_store import VectorStore
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
|
||||
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
|
||||
|
||||
schema = tool.schema
|
||||
parameters: Dict[str, object] = {}
|
||||
if "parameters" in schema:
|
||||
parameters = {
|
||||
"type": schema["parameters"]["type"],
|
||||
"properties": schema["parameters"]["properties"],
|
||||
}
|
||||
if "required" in schema["parameters"]:
|
||||
parameters["required"] = schema["parameters"]["required"]
|
||||
|
||||
function_def = FunctionDefinition(
|
||||
name=schema["name"],
|
||||
description=schema.get("description", ""),
|
||||
parameters=parameters,
|
||||
)
|
||||
return FunctionToolParam(type="function", function=function_def)
|
||||
|
||||
|
||||
class OpenAIAssistantAgentState(BaseModel):
|
||||
type: str = Field(default="OpenAIAssistantAgentState")
|
||||
assistant_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
initial_message_ids: List[str] = Field(default_factory=list)
|
||||
vector_store_id: Optional[str] = None
|
||||
uploaded_file_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(BaseChatAgent):
|
||||
"""An agent implementation that uses the Assistant API to generate responses.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[openai]" # For OpenAI Assistant
|
||||
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
|
||||
|
||||
|
||||
This agent leverages the Assistant API to create AI assistants with capabilities like:
|
||||
|
||||
* Code interpretation and execution
|
||||
* File handling and search
|
||||
* Custom function calling
|
||||
* Multi-turn conversations
|
||||
|
||||
The agent maintains a thread of conversation and can use various tools including
|
||||
|
||||
* Code interpreter: For executing code and working with files
|
||||
* File search: For searching through uploaded documents
|
||||
* Custom functions: For extending capabilities with user-defined tools
|
||||
|
||||
Key Features:
|
||||
|
||||
* Supports multiple file formats including code, documents, images
|
||||
* Can handle up to 128 tools per assistant
|
||||
* Maintains conversation context in threads
|
||||
* Supports file uploads for code interpreter and search
|
||||
* Vector store integration for efficient file search
|
||||
* Automatic file parsing and embedding
|
||||
|
||||
You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
Use the assistant to analyze data in a CSV file:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from agentdhal_core import CancellationToken
|
||||
import asyncio
|
||||
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def example():
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create an OpenAI client
|
||||
client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url")
|
||||
|
||||
# Create an assistant with code interpreter
|
||||
assistant = OpenAIAssistantAgent(
|
||||
name="PythonHelper",
|
||||
description="Helps with Python programming",
|
||||
client=client,
|
||||
model="gpt-4",
|
||||
instructions="You are a helpful Python programming assistant.",
|
||||
tools=["code_interpreter"],
|
||||
)
|
||||
|
||||
# Upload files for the assistant to use
|
||||
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
|
||||
|
||||
# Get response from the assistant
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_uploaded_files(cancellation_token)
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Use Azure OpenAI Assistant with AAD authentication:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncAzureOpenAI
|
||||
import asyncio
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
|
||||
|
||||
async def example():
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Create an Azure OpenAI client
|
||||
token_provider = get_bearer_token_provider(DefaultAzureCredential())
|
||||
client = AsyncAzureOpenAI(
|
||||
azure_deployment="YOUR_AZURE_DEPLOYMENT",
|
||||
api_version="YOUR_API_VERSION",
|
||||
azure_endpoint="YOUR_AZURE_ENDPOINT",
|
||||
azure_ad_token_provider=token_provider,
|
||||
)
|
||||
|
||||
# Create an assistant with code interpreter
|
||||
assistant = OpenAIAssistantAgent(
|
||||
name="PythonHelper",
|
||||
description="Helps with Python programming",
|
||||
client=client,
|
||||
model="gpt-4o",
|
||||
instructions="You are a helpful Python programming assistant.",
|
||||
tools=["code_interpreter"],
|
||||
)
|
||||
|
||||
# Get response from the assistant
|
||||
response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token)
|
||||
|
||||
print(response)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
||||
|
||||
asyncio.run(example())
|
||||
|
||||
Args:
|
||||
name (str): Name of the assistant
|
||||
description (str): Description of the assistant's purpose
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance
|
||||
model (str): Model to use (e.g. "gpt-4")
|
||||
instructions (str): System instructions for the assistant
|
||||
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
|
||||
assistant_id (Optional[str]): ID of existing assistant to use
|
||||
thread_id (Optional[str]): ID of existing thread to use
|
||||
metadata (Optional[Dict[str, str]]): Additional metadata for the assistant.
|
||||
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
|
||||
temperature (Optional[float]): Temperature for response generation
|
||||
tool_resources (Optional[ToolResources]): Additional tool configuration
|
||||
top_p (Optional[float]): Top p sampling parameter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI,
|
||||
model: str,
|
||||
instructions: str,
|
||||
tools: Optional[
|
||||
Iterable[
|
||||
Union[
|
||||
Literal["code_interpreter", "file_search"],
|
||||
Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
|
||||
]
|
||||
]
|
||||
] = None,
|
||||
assistant_id: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional["AssistantResponseFormatOptionParam"] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_resources: Optional["ToolResources"] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> None:
|
||||
if isinstance(client, ChatCompletionClient):
|
||||
raise ValueError(
|
||||
"Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
|
||||
)
|
||||
|
||||
super().__init__(name, description)
|
||||
if tools is None:
|
||||
tools = []
|
||||
|
||||
# Store original tools and converted tools separately
|
||||
self._original_tools: List[Tool] = []
|
||||
converted_tools: List["AssistantToolParam"] = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, str):
|
||||
if tool == "code_interpreter":
|
||||
converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
|
||||
elif tool == "file_search":
|
||||
converted_tools.append(FileSearchToolParam(type="file_search"))
|
||||
elif isinstance(tool, Tool):
|
||||
self._original_tools.append(tool)
|
||||
converted_tools.append(_convert_tool_to_function_param(tool))
|
||||
elif callable(tool):
|
||||
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
|
||||
description = tool.__doc__
|
||||
else:
|
||||
description = ""
|
||||
function_tool = FunctionTool(tool, description=description)
|
||||
self._original_tools.append(function_tool)
|
||||
converted_tools.append(_convert_tool_to_function_param(function_tool))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
self._client = client
|
||||
self._assistant: Optional["Assistant"] = None
|
||||
self._thread: Optional["Thread"] = None
|
||||
self._init_thread_id = thread_id
|
||||
self._model = model
|
||||
self._instructions = instructions
|
||||
self._api_tools = converted_tools
|
||||
self._assistant_id = assistant_id
|
||||
self._metadata = metadata
|
||||
self._response_format = response_format
|
||||
self._temperature = temperature
|
||||
self._tool_resources = tool_resources
|
||||
self._top_p = top_p
|
||||
self._vector_store_id: Optional[str] = None
|
||||
self._uploaded_file_ids: List[str] = []
|
||||
|
||||
# Variables to track initial state
|
||||
self._initial_message_ids: Set[str] = set()
|
||||
self._initial_state_retrieved: bool = False
|
||||
|
||||
async def _ensure_initialized(self) -> None:
|
||||
"""Ensure assistant and thread are created."""
|
||||
if self._assistant is None:
|
||||
if self._assistant_id:
|
||||
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) # type: ignore[reportDeprecated]
|
||||
else:
|
||||
self._assistant = await self._client.beta.assistants.create( # type: ignore[reportDeprecated]
|
||||
model=self._model,
|
||||
description=self.description,
|
||||
instructions=self._instructions,
|
||||
tools=self._api_tools,
|
||||
metadata=self._metadata,
|
||||
response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore
|
||||
temperature=self._temperature,
|
||||
tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore
|
||||
top_p=self._top_p,
|
||||
)
|
||||
|
||||
if self._thread is None:
|
||||
if self._init_thread_id:
|
||||
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) # type: ignore[reportDeprecated]
|
||||
else:
|
||||
self._thread = await self._client.beta.threads.create() # type: ignore[reportDeprecated]
|
||||
|
||||
# Retrieve initial state only once
|
||||
if not self._initial_state_retrieved:
|
||||
await self._retrieve_initial_state()
|
||||
self._initial_state_retrieved = True
|
||||
|
||||
async def _retrieve_initial_state(self) -> None:
|
||||
"""Retrieve and store the initial state of messages and runs."""
|
||||
# Retrieve all initial message IDs
|
||||
initial_message_ids: Set[str] = set()
|
||||
after: str | NotGiven = NOT_GIVEN
|
||||
while True:
|
||||
msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list( # type: ignore[reportDeprecated]
|
||||
self._thread_id, after=after, order="asc", limit=100
|
||||
)
|
||||
for msg in msgs.data:
|
||||
initial_message_ids.add(msg.id)
|
||||
if not msgs.has_next_page():
|
||||
break
|
||||
after = msgs.data[-1].id
|
||||
self._initial_message_ids = initial_message_ids
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def threads(self) -> AsyncThreads:
|
||||
return self._client.beta.threads
|
||||
|
||||
@property
|
||||
def runs(self) -> AsyncRuns:
|
||||
return self._client.beta.threads.runs
|
||||
|
||||
@property
|
||||
def messages(self) -> AsyncMessages:
|
||||
return self._client.beta.threads.messages
|
||||
|
||||
@property
|
||||
def _get_assistant_id(self) -> str:
|
||||
if self._assistant is None:
|
||||
raise ValueError("Assistant not initialized")
|
||||
return self._assistant.id
|
||||
|
||||
@property
|
||||
def _thread_id(self) -> str:
|
||||
if self._thread is None:
|
||||
raise ValueError("Thread not initialized")
|
||||
return self._thread.id
|
||||
|
||||
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
|
||||
"""Execute a tool call and return the result."""
|
||||
if not self._original_tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
result = await tool.run_json(arguments, cancellation_token, call_id=tool_call.id)
|
||||
return tool.return_value_as_string(result)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handle incoming messages and return a response."""
|
||||
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""Handle incoming messages and return a response."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# Process all messages in sequence
|
||||
for message in messages:
|
||||
await self.handle_incoming_message(message, cancellation_token)
|
||||
|
||||
# Inner messages for tool calls
|
||||
inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
|
||||
# Create and start a run
|
||||
run: Run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.create( # type: ignore[reportDeprecated]
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._get_assistant_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for run completion by polling
|
||||
while True:
|
||||
run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.retrieve( # type: ignore[reportDeprecated]
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if run.status == "failed":
|
||||
raise ValueError(f"Run failed: {run.last_error}")
|
||||
|
||||
# If the run requires action (function calls), execute tools and continue
|
||||
if run.status == "requires_action" and run.required_action is not None:
|
||||
tool_calls: List[FunctionCall] = []
|
||||
for required_tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
||||
if required_tool_call.type == "function":
|
||||
tool_calls.append(
|
||||
FunctionCall(
|
||||
id=required_tool_call.id,
|
||||
name=required_tool_call.function.name,
|
||||
arguments=required_tool_call.function.arguments,
|
||||
)
|
||||
)
|
||||
|
||||
# Add tool call message to inner messages
|
||||
tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
|
||||
inner_messages.append(tool_call_msg)
|
||||
event_logger.debug(tool_call_msg)
|
||||
yield tool_call_msg
|
||||
|
||||
# Execute tool calls and get results
|
||||
tool_outputs: List[FunctionExecutionResult] = []
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
result = await self._execute_tool_call(tool_call, cancellation_token)
|
||||
is_error = False
|
||||
except Exception as e:
|
||||
result = f"Error: {e}"
|
||||
is_error = True
|
||||
tool_outputs.append(
|
||||
FunctionExecutionResult(
|
||||
content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
|
||||
)
|
||||
)
|
||||
|
||||
# Add tool result message to inner messages
|
||||
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
|
||||
inner_messages.append(tool_result_msg)
|
||||
event_logger.debug(tool_result_msg)
|
||||
yield tool_result_msg
|
||||
|
||||
# Submit tool outputs back to the run
|
||||
run = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.runs.submit_tool_outputs( # type: ignore[reportDeprecated]
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
|
||||
)
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if run.status == "completed":
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get messages after run completion
|
||||
assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1) # type: ignore[reportDeprecated]
|
||||
)
|
||||
)
|
||||
|
||||
if not assistant_messages.data:
|
||||
raise ValueError("No messages received from assistant")
|
||||
|
||||
# Get the last message's content
|
||||
last_message = assistant_messages.data[0]
|
||||
if not last_message.content:
|
||||
raise ValueError(f"No content in the last message: {last_message}")
|
||||
|
||||
# Extract text content
|
||||
text_content = [content for content in last_message.content if content.type == "text"]
|
||||
if not text_content:
|
||||
raise ValueError(f"Expected text content in the last message: {last_message.content}")
|
||||
|
||||
# Return the assistant's response as a Response with inner messages
|
||||
chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
|
||||
yield Response(chat_message=chat_message, inner_messages=inner_messages)
|
||||
|
||||
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle regular text messages by adding them to the thread."""
|
||||
content: str | List[MessageContentPartParam] | None = None
|
||||
llm_message = message.to_model_message()
|
||||
if isinstance(llm_message.content, str):
|
||||
content = llm_message.content
|
||||
else:
|
||||
content = []
|
||||
for c in llm_message.content:
|
||||
if isinstance(c, str):
|
||||
content.append(TextContentBlockParam(text=c, type="text"))
|
||||
elif isinstance(c, Image):
|
||||
content.append(ImageURLContentBlockParam(image_url=ImageURLParam(url=c.data_uri), type="image_url"))
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(c)} in {message}")
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.create( # type: ignore[reportDeprecated]
|
||||
thread_id=self._thread_id,
|
||||
content=content,
|
||||
role="user",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle reset command by deleting new messages and runs since initialization."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# Retrieve all message IDs in the thread
|
||||
new_message_ids: List[str] = []
|
||||
after: str | NotGiven = NOT_GIVEN
|
||||
while True:
|
||||
msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100) # type: ignore[reportDeprecated]
|
||||
)
|
||||
)
|
||||
for msg in msgs.data:
|
||||
if msg.id not in self._initial_message_ids:
|
||||
new_message_ids.append(msg.id)
|
||||
if not msgs.has_next_page():
|
||||
break
|
||||
after = msgs.data[-1].id
|
||||
|
||||
# Delete new messages
|
||||
for msg_id in new_message_ids:
|
||||
status: MessageDeleted = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) # type: ignore[reportDeprecated]
|
||||
)
|
||||
)
|
||||
assert status.deleted is True
|
||||
|
||||
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
|
||||
"""Upload files and return their IDs."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
if isinstance(file_paths, str):
|
||||
file_paths = [file_paths]
|
||||
|
||||
file_ids: List[str] = []
|
||||
for file_path in file_paths:
|
||||
async with aiofiles.open(file_path, mode="rb") as f:
|
||||
file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read()))
|
||||
file_name = os.path.basename(file_path)
|
||||
|
||||
file: FileObject = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants"))
|
||||
)
|
||||
file_ids.append(file.id)
|
||||
self._uploaded_file_ids.append(file.id)
|
||||
|
||||
return file_ids
|
||||
|
||||
async def on_upload_for_code_interpreter(
|
||||
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle file uploads for the code interpreter."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
file_ids = await self._upload_files(file_paths, cancellation_token)
|
||||
|
||||
# Update thread with the new files
|
||||
thread = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id)) # type: ignore[reportDeprecated]
|
||||
)
|
||||
tool_resources: ToolResources = thread.tool_resources or ToolResources()
|
||||
code_interpreter: ToolResourcesCodeInterpreter = (
|
||||
tool_resources.code_interpreter or ToolResourcesCodeInterpreter()
|
||||
)
|
||||
existing_file_ids: List[str] = code_interpreter.file_ids or []
|
||||
existing_file_ids.extend(file_ids)
|
||||
tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=existing_file_ids)
|
||||
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.threads.update( # type: ignore[reportDeprecated]
|
||||
thread_id=self._thread_id,
|
||||
tool_resources=cast(thread_update_params.ToolResources, tool_resources.model_dump()),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def on_upload_for_file_search(
|
||||
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle file uploads for file search."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# Check if file_search is enabled in tools
|
||||
if not any(tool.get("type") == "file_search" for tool in self._api_tools):
|
||||
raise ValueError(
|
||||
"File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
|
||||
)
|
||||
|
||||
# Create vector store if not already created
|
||||
if self._vector_store_id is None:
|
||||
vector_store: VectorStore = await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.vector_stores.create())
|
||||
)
|
||||
self._vector_store_id = vector_store.id
|
||||
|
||||
# Update assistant with vector store ID
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.beta.assistants.update(
|
||||
assistant_id=self._get_assistant_id,
|
||||
tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
file_ids = await self._upload_files(file_paths, cancellation_token)
|
||||
|
||||
# Create file batch with the file IDs
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(
|
||||
self._client.vector_stores.file_batches.create_and_poll(
|
||||
vector_store_id=self._vector_store_id, file_ids=file_ids
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete all files that were uploaded by this agent instance."""
|
||||
await self._ensure_initialized()
|
||||
for file_id in self._uploaded_file_ids:
|
||||
try:
|
||||
await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete file {file_id}: {str(e)}")
|
||||
self._uploaded_file_ids = []
|
||||
|
||||
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the assistant if it was created by this instance."""
|
||||
await self._ensure_initialized()
|
||||
if self._assistant is not None and not self._assistant_id:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)) # type: ignore[reportDeprecated]
|
||||
)
|
||||
self._assistant = None
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete assistant: {str(e)}")
|
||||
|
||||
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the vector store if it was created by this instance."""
|
||||
await self._ensure_initialized()
|
||||
if self._vector_store_id is not None:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=self._vector_store_id))
|
||||
)
|
||||
self._vector_store_id = None
|
||||
except Exception as e:
|
||||
event_logger.error(f"Failed to delete vector store: {str(e)}")
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state = OpenAIAssistantAgentState(
|
||||
assistant_id=self._assistant.id if self._assistant else self._assistant_id,
|
||||
thread_id=self._thread.id if self._thread else self._init_thread_id,
|
||||
initial_message_ids=list(self._initial_message_ids),
|
||||
vector_store_id=self._vector_store_id,
|
||||
uploaded_file_ids=self._uploaded_file_ids,
|
||||
)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
agent_state = OpenAIAssistantAgentState.model_validate(state)
|
||||
self._assistant_id = agent_state.assistant_id
|
||||
self._init_thread_id = agent_state.thread_id
|
||||
self._initial_message_ids = set(agent_state.initial_message_ids)
|
||||
self._vector_store_id = agent_state.vector_store_id
|
||||
self._uploaded_file_ids = agent_state.uploaded_file_ids
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._video_surfer import VideoSurfer
|
||||
|
||||
__all__ = ["VideoSurfer"]
|
||||
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .tools import (
|
||||
extract_audio,
|
||||
get_screenshot_at,
|
||||
get_video_length,
|
||||
save_screenshot,
|
||||
transcribe_audio_with_timestamps,
|
||||
transcribe_video_screenshot,
|
||||
)
|
||||
|
||||
|
||||
class VideoSurfer(AssistantAgent):
|
||||
"""
|
||||
VideoSurfer is a specialized agent designed to answer questions about a local video file.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[video-surfer]"
|
||||
|
||||
This agent utilizes various tools to extract information from the video, such as its length, screenshots at specific timestamps, and audio transcriptions. It processes these elements to provide detailed answers to user queries.
|
||||
|
||||
Available tools:
|
||||
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio`
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length`
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps`
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at`
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot`
|
||||
- :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot`
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client used for generating responses.
|
||||
tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
|
||||
A list of tools or functions the agent can use. If not provided, defaults to all video tools from the action space.
|
||||
description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
|
||||
system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.
|
||||
|
||||
Example usage:
|
||||
|
||||
The following example demonstrates how to create an video surfing agent with
|
||||
a model client and generate a response to a simple query about a local video
|
||||
called video.mp4.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.agents.video_surfer import VideoSurfer
|
||||
|
||||
async def main() -> None:
|
||||
\"\"\"
|
||||
Main function to run the video agent.
|
||||
\"\"\"
|
||||
# Define an agent
|
||||
video_agent = VideoSurfer(
|
||||
name="VideoSurfer",
|
||||
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
|
||||
)
|
||||
|
||||
# Define termination condition
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
|
||||
# Define a team
|
||||
agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination)
|
||||
|
||||
# Run the team and stream messages to the console
|
||||
stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of complex does his use? Can you save this example to disk as well?")
|
||||
await Console(stream)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.teams import MagenticOneGroupChat
|
||||
from agentdhal_agentchat.agents import UserProxyAgent
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.agents.video_surfer import VideoSurfer
|
||||
|
||||
async def main() -> None:
|
||||
\"\"\"
|
||||
Main function to run the video agent.
|
||||
\"\"\"
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
|
||||
|
||||
# Define an agent
|
||||
video_agent = VideoSurfer(
|
||||
name="VideoSurfer",
|
||||
model_client=model_client
|
||||
)
|
||||
|
||||
web_surfer_agent = UserProxyAgent(
|
||||
name="User"
|
||||
)
|
||||
|
||||
# Define a team
|
||||
agent_team = MagenticOneGroupChat([web_surfer_agent, video_agent], model_client=model_client,)
|
||||
|
||||
# Run the team and stream messages to the console
|
||||
stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.")
|
||||
await Console(stream)
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video."
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = """
|
||||
You are a helpful agent that is an expert at answering questions from a video.
|
||||
When asked to answer a question about a video, you should:
|
||||
1. Check if that video is available locally.
|
||||
2. Use the transcription to find which part of the video the question is referring to.
|
||||
3. Optionally use screenshots from those timestamps
|
||||
4. Provide a detailed answer to the question.
|
||||
Reply with TERMINATE when the task has been completed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
|
||||
description: Optional[str] = None,
|
||||
system_message: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
model_client=model_client,
|
||||
tools=tools
|
||||
or [
|
||||
get_video_length,
|
||||
get_screenshot_at,
|
||||
save_screenshot,
|
||||
self.vs_transribe_video_screenshot,
|
||||
extract_audio,
|
||||
transcribe_audio_with_timestamps,
|
||||
],
|
||||
description=description or self.DEFAULT_DESCRIPTION,
|
||||
system_message=system_message or self.DEFAULT_SYSTEM_MESSAGE,
|
||||
)
|
||||
|
||||
async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str:
|
||||
"""
|
||||
Transcribes the video screenshot at a specific timestamp.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to the video file.
|
||||
timestamp (float): Timestamp to take the screenshot.
|
||||
|
||||
Returns:
|
||||
str: Transcription of the video screenshot.
|
||||
"""
|
||||
return await transcribe_video_screenshot(video_path, timestamp, self._model_client)
|
||||
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import base64
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import cv2
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import whisper
|
||||
from agentdhal_core import Image as AGImage
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
def extract_audio(video_path: str, audio_output_path: str) -> str:
|
||||
"""
|
||||
Extracts audio from a video file and saves it as an MP3 file.
|
||||
|
||||
:param video_path: Path to the video file.
|
||||
:param audio_output_path: Path to save the extracted audio file.
|
||||
:return: Confirmation message with the path to the saved audio file.
|
||||
"""
|
||||
(ffmpeg.input(video_path).output(audio_output_path, format="mp3").run(quiet=True, overwrite_output=True)) # type: ignore
|
||||
return f"Audio extracted and saved to {audio_output_path}."
|
||||
|
||||
|
||||
def transcribe_audio_with_timestamps(audio_path: str) -> str:
|
||||
"""
|
||||
Transcribes the audio file with timestamps using the Whisper model.
|
||||
|
||||
:param audio_path: Path to the audio file.
|
||||
:return: Transcription with timestamps.
|
||||
"""
|
||||
model = whisper.load_model("base") # type: ignore
|
||||
result: Dict[str, Any] = model.transcribe(audio_path, task="transcribe", language="en", verbose=False) # type: ignore
|
||||
|
||||
segments: List[Dict[str, Any]] = result["segments"]
|
||||
transcription_with_timestamps = ""
|
||||
|
||||
for segment in segments:
|
||||
start: float = segment["start"]
|
||||
end: float = segment["end"]
|
||||
text: str = segment["text"]
|
||||
transcription_with_timestamps += f"[{start:.2f} - {end:.2f}] {text}\n"
|
||||
|
||||
return transcription_with_timestamps
|
||||
|
||||
|
||||
def get_video_length(video_path: str) -> str:
|
||||
"""
|
||||
Returns the length of the video in seconds.
|
||||
|
||||
:param video_path: Path to the video file.
|
||||
:return: Duration of the video in seconds.
|
||||
"""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise IOError(f"Cannot open video file {video_path}")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
duration = frame_count / fps
|
||||
cap.release()
|
||||
|
||||
return f"The video is {duration:.2f} seconds long."
|
||||
|
||||
|
||||
def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None:
|
||||
"""
|
||||
Captures a screenshot at the specified timestamp and saves it to the output path.
|
||||
|
||||
:param video_path: Path to the video file.
|
||||
:param timestamp: Timestamp in seconds.
|
||||
:param output_path: Path to save the screenshot. The file format is determined by the extension in the path.
|
||||
"""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise IOError(f"Cannot open video file {video_path}")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_number = int(timestamp * fps)
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
cv2.imwrite(output_path, frame)
|
||||
else:
|
||||
raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
|
||||
cap.release()
|
||||
|
||||
|
||||
async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str:
|
||||
"""
|
||||
Transcribes the content of a video screenshot captured at the specified timestamp using OpenAI API.
|
||||
|
||||
:param video_path: Path to the video file.
|
||||
:param timestamp: Timestamp in seconds.
|
||||
:param model_client: ChatCompletionClient instance.
|
||||
:return: Description of the screenshot content.
|
||||
"""
|
||||
screenshots = get_screenshot_at(video_path, [timestamp])
|
||||
if not screenshots:
|
||||
return "Failed to capture screenshot."
|
||||
|
||||
_, frame = screenshots[0]
|
||||
# Convert the frame to bytes and then to base64 encoding
|
||||
_, buffer = cv2.imencode(".jpg", frame)
|
||||
frame_bytes = buffer.tobytes()
|
||||
frame_base64 = base64.b64encode(frame_bytes).decode("utf-8")
|
||||
screenshot_uri = f"data:image/jpeg;base64,{frame_base64}"
|
||||
|
||||
messages = [
|
||||
UserMessage(
|
||||
content=[
|
||||
"Following is a screenshot from the video at {} seconds. Describe what you see here.",
|
||||
AGImage.from_uri(screenshot_uri),
|
||||
],
|
||||
source="tool",
|
||||
)
|
||||
]
|
||||
|
||||
result = await model_client.create(messages=messages)
|
||||
return str(result.content)
|
||||
|
||||
|
||||
def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]:
|
||||
"""
|
||||
Captures screenshots at the specified timestamps and returns them as Python objects.
|
||||
|
||||
:param video_path: Path to the video file.
|
||||
:param timestamps: List of timestamps in seconds.
|
||||
:return: List of tuples containing timestamp and the corresponding frame (image).
|
||||
Each frame is a NumPy array (height x width x channels).
|
||||
"""
|
||||
screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = []
|
||||
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise IOError(f"Cannot open video file {video_path}")
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
duration = total_frames / fps
|
||||
|
||||
for timestamp in timestamps:
|
||||
if 0 <= timestamp <= duration:
|
||||
frame_number = int(timestamp * fps)
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
# Append the timestamp and frame to the list
|
||||
screenshots.append((timestamp, frame))
|
||||
else:
|
||||
raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
|
||||
else:
|
||||
raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]")
|
||||
|
||||
cap.release()
|
||||
return screenshots
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._multimodal_web_surfer import MultimodalWebSurfer
|
||||
from .playwright_controller import PlaywrightController
|
||||
|
||||
__all__ = ["MultimodalWebSurfer", "PlaywrightController"]
|
||||
11
agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py
Normal file
11
agent_dhal/agentdhal_extensions/agents/web_surfer/_events.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSurferEvent:
|
||||
source: str
|
||||
message: str
|
||||
url: str
|
||||
action: str | None = None
|
||||
arguments: Dict[str, Any] | None = None
|
||||
@@ -0,0 +1,988 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import aiofiles
|
||||
import PIL.Image
|
||||
from agentdhal_agentchat.agents import BaseChatAgent
|
||||
from agentdhal_agentchat.base import Response
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
|
||||
from agentdhal_agentchat.utils import content_to_str, remove_images
|
||||
from agentdhal_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
|
||||
from agentdhal_core import Image as AGImage
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from PIL import Image
|
||||
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._events import WebSurferEvent
|
||||
from ._prompts import (
|
||||
WEB_SURFER_QA_PROMPT,
|
||||
WEB_SURFER_QA_SYSTEM_MESSAGE,
|
||||
WEB_SURFER_TOOL_PROMPT_MM,
|
||||
WEB_SURFER_TOOL_PROMPT_TEXT,
|
||||
)
|
||||
from ._set_of_mark import add_set_of_mark
|
||||
from ._tool_definitions import (
|
||||
TOOL_CLICK,
|
||||
TOOL_HISTORY_BACK,
|
||||
TOOL_HOVER,
|
||||
TOOL_READ_PAGE_AND_ANSWER,
|
||||
TOOL_SCROLL_DOWN,
|
||||
TOOL_SCROLL_UP,
|
||||
TOOL_SLEEP,
|
||||
TOOL_SUMMARIZE_PAGE,
|
||||
TOOL_TYPE,
|
||||
TOOL_VISIT_URL,
|
||||
TOOL_WEB_SEARCH,
|
||||
)
|
||||
from ._types import InteractiveRegion, UserContent
|
||||
from .playwright_controller import PlaywrightController
|
||||
|
||||
DEFAULT_CONTEXT_SIZE = 128000
|
||||
|
||||
|
||||
class MultimodalWebSurferConfig(BaseModel):
|
||||
name: str
|
||||
model_client: ComponentModel
|
||||
downloads_folder: str | None = None
|
||||
description: str | None = None
|
||||
debug_dir: str | None = None
|
||||
headless: bool = True
|
||||
start_page: str | None = "https://www.bing.com/"
|
||||
animate_actions: bool = False
|
||||
to_save_screenshots: bool = False
|
||||
use_ocr: bool = False
|
||||
browser_channel: str | None = None
|
||||
browser_data_dir: str | None = None
|
||||
to_resize_viewport: bool = True
|
||||
|
||||
|
||||
class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
|
||||
"""
|
||||
MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages.
|
||||
|
||||
Installation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[web-surfer]"
|
||||
|
||||
It launches a chromium browser and allows the playwright to interact with the web browser and can perform a variety of actions. The browser is launched on the first call to the agent and is reused for subsequent calls.
|
||||
|
||||
It must be used with a multimodal model client that supports function/tool calling, ideally GPT-4o currently.
|
||||
|
||||
|
||||
When :meth:`on_messages` or :meth:`on_messages_stream` is called, the following occurs:
|
||||
1) If this is the first call, the browser is initialized and the page is loaded. This is done in :meth:`_lazy_init`. The browser is only closed when :meth:`close` is called.
|
||||
2) The method :meth:`_generate_reply` is called, which then creates the final response as below.
|
||||
3) The agent takes a screenshot of the page, extracts the interactive elements, and prepares a set-of-mark screenshot with bounding boxes around the interactive elements.
|
||||
4) The agent makes a call to the :attr:`model_client` with the SOM screenshot, history of messages, and the list of available tools.
|
||||
- If the model returns a string, the agent returns the string as the final response.
|
||||
- If the model returns a list of tool calls, the agent executes the tool calls with :meth:`_execute_tool` using :attr:`_playwright_controller`.
|
||||
- The agent returns a final response which includes a screenshot of the page, page metadata, description of the action taken and the inner text of the webpage.
|
||||
5) If at any point the agent encounters an error, it returns the error message as the final response.
|
||||
|
||||
|
||||
.. note::
|
||||
Please note that using the MultimodalWebSurfer involves interacting with a digital world designed for humans, which carries inherent risks.
|
||||
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences.
|
||||
Moreover, be cautious that MultimodalWebSurfer may be susceptible to prompt injection attacks from webpages.
|
||||
|
||||
.. note::
|
||||
|
||||
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client used by the agent. Must be multimodal and support function calling.
|
||||
downloads_folder (str, optional): The folder where downloads are saved. Defaults to None, no downloads are saved.
|
||||
description (str, optional): The description of the agent. Defaults to MultimodalWebSurfer.DEFAULT_DESCRIPTION.
|
||||
debug_dir (str, optional): The directory where debug information is saved. Defaults to None.
|
||||
headless (bool, optional): Whether the browser should be headless. Defaults to True.
|
||||
start_page (str, optional): The start page for the browser. Defaults to MultimodalWebSurfer.DEFAULT_START_PAGE.
|
||||
animate_actions (bool, optional): Whether to animate actions. Defaults to False.
|
||||
to_save_screenshots (bool, optional): Whether to save screenshots. Defaults to False.
|
||||
use_ocr (bool, optional): Whether to use OCR. Defaults to False.
|
||||
browser_channel (str, optional): The browser channel. Defaults to None.
|
||||
browser_data_dir (str, optional): The browser data directory. Defaults to None.
|
||||
to_resize_viewport (bool, optional): Whether to resize the viewport. Defaults to True.
|
||||
playwright (Playwright, optional): The playwright instance. Defaults to None.
|
||||
context (BrowserContext, optional): The browser context. Defaults to None.
|
||||
|
||||
|
||||
|
||||
|
||||
Example usage:
|
||||
|
||||
The following example demonstrates how to create a web surfing agent with
|
||||
a model client and run it for multiple turns.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Define an agent
|
||||
web_surfer_agent = MultimodalWebSurfer(
|
||||
name="MultimodalWebSurfer",
|
||||
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
|
||||
)
|
||||
|
||||
# Define a team
|
||||
agent_team = RoundRobinGroupChat([web_surfer_agent], max_turns=3)
|
||||
|
||||
# Run the team and stream messages to the console
|
||||
stream = agent_team.run_stream(task="Navigate to the AutoGen readme on GitHub.")
|
||||
await Console(stream)
|
||||
# Close the browser controlled by the agent
|
||||
await web_surfer_agent.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "agent"
|
||||
component_config_schema = MultimodalWebSurferConfig
|
||||
component_provider_override = "agentdhal_extensions.agents.web_surfer.MultimodalWebSurfer"
|
||||
|
||||
DEFAULT_DESCRIPTION = """
|
||||
A helpful assistant with access to a web browser.
|
||||
Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.).
|
||||
It can also summarize the entire page, or answer questions based on the content of the page.
|
||||
It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded.
|
||||
"""
|
||||
DEFAULT_START_PAGE = "https://www.bing.com/"
|
||||
|
||||
# Viewport dimensions
|
||||
VIEWPORT_HEIGHT = 900
|
||||
VIEWPORT_WIDTH = 1440
|
||||
|
||||
# Size of the image we send to the MLM
|
||||
# Current values represent a 0.85 scaling to fit within the GPT-4v short-edge constraints (768px)
|
||||
MLM_HEIGHT = 765
|
||||
MLM_WIDTH = 1224
|
||||
|
||||
SCREENSHOT_TOKENS = 1105
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
downloads_folder: str | None = None,
|
||||
description: str = DEFAULT_DESCRIPTION,
|
||||
debug_dir: str | None = None,
|
||||
headless: bool = True,
|
||||
start_page: str | None = DEFAULT_START_PAGE,
|
||||
animate_actions: bool = False,
|
||||
to_save_screenshots: bool = False,
|
||||
use_ocr: bool = False,
|
||||
browser_channel: str | None = None,
|
||||
browser_data_dir: str | None = None,
|
||||
to_resize_viewport: bool = True,
|
||||
playwright: Playwright | None = None,
|
||||
context: BrowserContext | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the MultimodalWebSurfer.
|
||||
"""
|
||||
super().__init__(name, description)
|
||||
if debug_dir is None and to_save_screenshots:
|
||||
raise ValueError(
|
||||
"Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist."
|
||||
)
|
||||
if model_client.model_info["function_calling"] is False:
|
||||
raise ValueError(
|
||||
"The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
|
||||
)
|
||||
|
||||
self._model_client = model_client
|
||||
self.headless = headless
|
||||
self.browser_channel = browser_channel
|
||||
self.browser_data_dir = browser_data_dir
|
||||
self.start_page = start_page or self.DEFAULT_START_PAGE
|
||||
self.downloads_folder = downloads_folder
|
||||
self.debug_dir = debug_dir
|
||||
self.to_save_screenshots = to_save_screenshots
|
||||
self.use_ocr = use_ocr
|
||||
self.to_resize_viewport = to_resize_viewport
|
||||
self.animate_actions = animate_actions
|
||||
|
||||
# Call init to set these in case not set
|
||||
self._playwright: Playwright | None = playwright
|
||||
self._context: BrowserContext | None = context
|
||||
self._page: Page | None = None
|
||||
self._last_download: Download | None = None
|
||||
self._prior_metadata_hash: str | None = None
|
||||
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.name}.MultimodalWebSurfer")
|
||||
self._chat_history: List[LLMMessage] = []
|
||||
|
||||
# Define the download handler
|
||||
def _download_handler(download: Download) -> None:
|
||||
self._last_download = download
|
||||
|
||||
self._download_handler = _download_handler
|
||||
|
||||
# Define the Playwright controller that handles the browser interactions
|
||||
self._playwright_controller = PlaywrightController(
|
||||
animate_actions=self.animate_actions,
|
||||
downloads_folder=self.downloads_folder,
|
||||
viewport_width=self.VIEWPORT_WIDTH,
|
||||
viewport_height=self.VIEWPORT_HEIGHT,
|
||||
_download_handler=self._download_handler,
|
||||
to_resize_viewport=self.to_resize_viewport,
|
||||
)
|
||||
self.default_tools = [
|
||||
TOOL_VISIT_URL,
|
||||
TOOL_WEB_SEARCH,
|
||||
TOOL_HISTORY_BACK,
|
||||
TOOL_CLICK,
|
||||
TOOL_TYPE,
|
||||
TOOL_READ_PAGE_AND_ANSWER,
|
||||
TOOL_SUMMARIZE_PAGE,
|
||||
TOOL_SLEEP,
|
||||
TOOL_HOVER,
|
||||
]
|
||||
self.did_lazy_init = False # flag to check if we have initialized the browser
|
||||
|
||||
async def _lazy_init(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
On the first call, we initialize the browser and the page.
|
||||
"""
|
||||
|
||||
# Check the current event loop policy if on windows.
|
||||
if sys.platform == "win32":
|
||||
current_policy = asyncio.get_event_loop_policy()
|
||||
if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance(
|
||||
current_policy, asyncio.WindowsProactorEventLoopPolicy
|
||||
):
|
||||
warnings.warn(
|
||||
"The current event loop policy is not WindowsProactorEventLoopPolicy. "
|
||||
"This may cause issues with subprocesses. "
|
||||
"Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
|
||||
"For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
|
||||
"See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._last_download = None
|
||||
self._prior_metadata_hash = None
|
||||
|
||||
# Create the playwright self
|
||||
launch_args: Dict[str, Any] = {"headless": self.headless}
|
||||
if self.browser_channel is not None:
|
||||
launch_args["channel"] = self.browser_channel
|
||||
if self._playwright is None:
|
||||
self._playwright = await async_playwright().start()
|
||||
|
||||
# Create the context -- are we launching persistent?
|
||||
if self._context is None:
|
||||
if self.browser_data_dir is None:
|
||||
browser = await self._playwright.chromium.launch(**launch_args)
|
||||
self._context = await browser.new_context(
|
||||
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0"
|
||||
)
|
||||
else:
|
||||
self._context = await self._playwright.chromium.launch_persistent_context(
|
||||
self.browser_data_dir, **launch_args
|
||||
)
|
||||
|
||||
# Create the page
|
||||
self._context.set_default_timeout(60000) # One minute
|
||||
self._page = await self._context.new_page()
|
||||
assert self._page is not None
|
||||
# self._page.route(lambda x: True, self._route_handler)
|
||||
self._page.on("download", self._download_handler)
|
||||
if self.to_resize_viewport:
|
||||
await self._page.set_viewport_size({"width": self.VIEWPORT_WIDTH, "height": self.VIEWPORT_HEIGHT})
|
||||
await self._page.add_init_script(
|
||||
path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
|
||||
)
|
||||
await self._page.goto(self.start_page)
|
||||
await self._page.wait_for_load_state()
|
||||
|
||||
# Prepare the debug directory -- which stores the screenshots generated throughout the process
|
||||
await self._set_debug_dir(self.debug_dir)
|
||||
self.did_lazy_init = True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the browser and the page.
|
||||
Should be called when the agent is no longer needed.
|
||||
"""
|
||||
if self._page is not None:
|
||||
await self._page.close()
|
||||
self._page = None
|
||||
if self._context is not None:
|
||||
await self._context.close()
|
||||
self._context = None
|
||||
if self._playwright is not None:
|
||||
await self._playwright.stop()
|
||||
self._playwright = None
|
||||
|
||||
async def _set_debug_dir(self, debug_dir: str | None) -> None:
|
||||
assert self._page is not None
|
||||
if self.debug_dir is None:
|
||||
return
|
||||
|
||||
if not os.path.isdir(self.debug_dir):
|
||||
os.mkdir(self.debug_dir)
|
||||
|
||||
if self.to_save_screenshots:
|
||||
current_timestamp = "_" + int(time.time()).__str__()
|
||||
screenshot_png_name = "screenshot" + current_timestamp + ".png"
|
||||
|
||||
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Screenshot: " + screenshot_png_name,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (MultiModalMessage,)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
if not self.did_lazy_init:
|
||||
return
|
||||
assert self._page is not None
|
||||
|
||||
self._chat_history.clear()
|
||||
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
|
||||
self._page, self.start_page
|
||||
)
|
||||
if reset_last_download and self._last_download is not None:
|
||||
self._last_download = None
|
||||
if reset_prior_metadata and self._prior_metadata_hash is not None:
|
||||
self._prior_metadata_hash = None
|
||||
if self.to_save_screenshots:
|
||||
current_timestamp = "_" + int(time.time()).__str__()
|
||||
screenshot_png_name = "screenshot" + current_timestamp + ".png"
|
||||
|
||||
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Screenshot: " + screenshot_png_name,
|
||||
)
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Resetting browser.",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
return message
|
||||
raise AssertionError("The stream should have returned the final result.")
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
for chat_message in messages:
|
||||
self._chat_history.append(chat_message.to_model_message())
|
||||
|
||||
self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
self.model_usage: List[RequestUsage] = []
|
||||
try:
|
||||
content = await self._generate_reply(cancellation_token=cancellation_token)
|
||||
self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
|
||||
final_usage = RequestUsage(
|
||||
prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]),
|
||||
completion_tokens=sum([u.completion_tokens for u in self.model_usage]),
|
||||
)
|
||||
if isinstance(content, str):
|
||||
yield Response(
|
||||
chat_message=TextMessage(content=content, source=self.name, models_usage=final_usage),
|
||||
inner_messages=self.inner_messages,
|
||||
)
|
||||
else:
|
||||
yield Response(
|
||||
chat_message=MultiModalMessage(content=content, source=self.name, models_usage=final_usage),
|
||||
inner_messages=self.inner_messages,
|
||||
)
|
||||
|
||||
except BaseException:
|
||||
content = f"Web surfing error:\n\n{traceback.format_exc()}"
|
||||
self._chat_history.append(AssistantMessage(content=content, source=self.name))
|
||||
yield Response(chat_message=TextMessage(content=content, source=self.name))
|
||||
|
||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:
|
||||
"""Generates the actual reply. First calls the LLM to figure out which tool to use, then executes the tool."""
|
||||
|
||||
# Lazy init, initialize the browser and the page on the first generate reply only
|
||||
if not self.did_lazy_init:
|
||||
await self._lazy_init()
|
||||
|
||||
assert self._page is not None
|
||||
|
||||
# Clone the messages, removing old screenshots
|
||||
history: List[LLMMessage] = remove_images(self._chat_history)
|
||||
|
||||
# Split the history, removing the last message
|
||||
if len(history):
|
||||
user_request = history.pop()
|
||||
else:
|
||||
user_request = UserMessage(content="Empty request.", source="user")
|
||||
|
||||
# Truncate the history for smaller models
|
||||
if self._model_client.model_info["family"] not in [
|
||||
ModelFamily.GPT_4O,
|
||||
ModelFamily.O1,
|
||||
ModelFamily.O3,
|
||||
ModelFamily.GPT_4,
|
||||
ModelFamily.GPT_35,
|
||||
]:
|
||||
history = []
|
||||
|
||||
# Ask the page for interactive elements, then prepare the state-of-mark screenshot
|
||||
rects = await self._playwright_controller.get_interactive_rects(self._page)
|
||||
viewport = await self._playwright_controller.get_visual_viewport(self._page)
|
||||
screenshot = await self._page.screenshot()
|
||||
som_screenshot, visible_rects, rects_above, rects_below = add_set_of_mark(screenshot, rects)
|
||||
|
||||
if self.to_save_screenshots:
|
||||
current_timestamp = "_" + int(time.time()).__str__()
|
||||
screenshot_png_name = "screenshot_som" + current_timestamp + ".png"
|
||||
som_screenshot.save(os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Screenshot: " + screenshot_png_name,
|
||||
)
|
||||
)
|
||||
# What tools are available?
|
||||
tools = self.default_tools.copy()
|
||||
|
||||
# We can scroll up
|
||||
if viewport["pageTop"] > 5:
|
||||
tools.append(TOOL_SCROLL_UP)
|
||||
|
||||
# Can scroll down
|
||||
if (viewport["pageTop"] + viewport["height"] + 5) < viewport["scrollHeight"]:
|
||||
tools.append(TOOL_SCROLL_DOWN)
|
||||
|
||||
# Focus hint
|
||||
focused = await self._playwright_controller.get_focused_rect_id(self._page)
|
||||
focused_hint = ""
|
||||
if focused:
|
||||
name = self._target_name(focused, rects)
|
||||
if name:
|
||||
name = f"(and name '{name}') "
|
||||
else:
|
||||
name = ""
|
||||
|
||||
role = "control"
|
||||
try:
|
||||
role = rects[focused]["role"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
focused_hint = f"\nThe {role} with ID {focused} {name}currently has the input focus.\n\n"
|
||||
|
||||
# Everything visible
|
||||
visible_targets = "\n".join(self._format_target_list(visible_rects, rects)) + "\n\n"
|
||||
|
||||
# Everything else
|
||||
other_targets: List[str] = []
|
||||
other_targets.extend(self._format_target_list(rects_above, rects))
|
||||
other_targets.extend(self._format_target_list(rects_below, rects))
|
||||
|
||||
if len(other_targets) > 0:
|
||||
if len(other_targets) > 30:
|
||||
other_targets = other_targets[0:30]
|
||||
other_targets.append("...")
|
||||
other_targets_str = (
|
||||
"Additional valid interaction targets include (but are not limited to):\n"
|
||||
+ "\n".join(other_targets)
|
||||
+ "\n\n"
|
||||
)
|
||||
else:
|
||||
other_targets_str = ""
|
||||
|
||||
state_description = "Your " + await self._get_state_description()
|
||||
tool_names = "\n".join([t["name"] for t in tools])
|
||||
page_title = await self._page.title()
|
||||
|
||||
prompt_message = None
|
||||
if self._model_client.model_info["vision"]:
|
||||
text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
|
||||
state_description=state_description,
|
||||
visible_targets=visible_targets,
|
||||
other_targets_str=other_targets_str,
|
||||
focused_hint=focused_hint,
|
||||
tool_names=tool_names,
|
||||
title=page_title,
|
||||
url=self._page.url,
|
||||
).strip()
|
||||
|
||||
# Scale the screenshot for the MLM, and close the original
|
||||
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
|
||||
som_screenshot.close()
|
||||
if self.to_save_screenshots:
|
||||
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore
|
||||
|
||||
# Create the message
|
||||
prompt_message = UserMessage(
|
||||
content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)],
|
||||
source=self.name,
|
||||
)
|
||||
else:
|
||||
text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
|
||||
state_description=state_description,
|
||||
visible_targets=visible_targets,
|
||||
other_targets_str=other_targets_str,
|
||||
focused_hint=focused_hint,
|
||||
tool_names=tool_names,
|
||||
title=page_title,
|
||||
url=self._page.url,
|
||||
).strip()
|
||||
|
||||
# Create the message
|
||||
prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name)
|
||||
|
||||
history.append(prompt_message)
|
||||
history.append(user_request)
|
||||
|
||||
# {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]}
|
||||
# print(f"""
|
||||
# ================={len(history)}=================
|
||||
# {history[-2].content}
|
||||
# =====
|
||||
# {history[-1].content}
|
||||
# ===================================================
|
||||
# """)
|
||||
|
||||
# Make the request
|
||||
response = await self._model_client.create(
|
||||
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
|
||||
) # , "parallel_tool_calls": False})
|
||||
|
||||
self.model_usage.append(response.usage)
|
||||
message = response.content
|
||||
self._last_download = None
|
||||
if isinstance(message, str):
|
||||
# Answer directly
|
||||
self.inner_messages.append(TextMessage(content=message, source=self.name))
|
||||
return message
|
||||
elif isinstance(message, list):
|
||||
# Take an action
|
||||
return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
|
||||
else:
|
||||
# Not sure what happened here
|
||||
raise AssertionError(f"Unknown response format '{message}'")
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
message: List[FunctionCall],
|
||||
rects: Dict[str, InteractiveRegion],
|
||||
tool_names: str,
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> UserContent:
|
||||
# Execute the tool
|
||||
name = message[0].name
|
||||
args = json.loads(message[0].arguments)
|
||||
action_description = ""
|
||||
assert self._page is not None
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
action=name,
|
||||
arguments=args,
|
||||
message=f"{name}( {json.dumps(args)} )",
|
||||
)
|
||||
)
|
||||
self.inner_messages.append(TextMessage(content=f"{name}( {json.dumps(args)} )", source=self.name))
|
||||
|
||||
if name == "visit_url":
|
||||
url = args.get("url")
|
||||
action_description = f"I typed '{url}' into the browser address bar."
|
||||
# Check if the argument starts with a known protocol
|
||||
if url.startswith(("https://", "http://", "file://", "about:")):
|
||||
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
|
||||
self._page, url
|
||||
)
|
||||
# If the argument contains a space, treat it as a search query
|
||||
elif " " in url:
|
||||
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
|
||||
self._page, f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH"
|
||||
)
|
||||
# Otherwise, prefix with https://
|
||||
else:
|
||||
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
|
||||
self._page, "https://" + url
|
||||
)
|
||||
if reset_last_download and self._last_download is not None:
|
||||
self._last_download = None
|
||||
if reset_prior_metadata and self._prior_metadata_hash is not None:
|
||||
self._prior_metadata_hash = None
|
||||
elif name == "history_back":
|
||||
action_description = "I clicked the browser back button."
|
||||
await self._playwright_controller.back(self._page)
|
||||
|
||||
elif name == "web_search":
|
||||
query = args.get("query")
|
||||
action_description = f"I typed '{query}' into the browser search bar."
|
||||
reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
|
||||
self._page, f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH"
|
||||
)
|
||||
if reset_last_download and self._last_download is not None:
|
||||
self._last_download = None
|
||||
if reset_prior_metadata and self._prior_metadata_hash is not None:
|
||||
self._prior_metadata_hash = None
|
||||
elif name == "scroll_up":
|
||||
action_description = "I scrolled up one page in the browser."
|
||||
await self._playwright_controller.page_up(self._page)
|
||||
elif name == "scroll_down":
|
||||
action_description = "I scrolled down one page in the browser."
|
||||
await self._playwright_controller.page_down(self._page)
|
||||
|
||||
elif name == "click":
|
||||
target_id = str(args.get("target_id"))
|
||||
target_name = self._target_name(target_id, rects)
|
||||
if target_name:
|
||||
action_description = f"I clicked '{target_name}'."
|
||||
else:
|
||||
action_description = "I clicked the control."
|
||||
new_page_tentative = await self._playwright_controller.click_id(self._page, target_id)
|
||||
if new_page_tentative is not None:
|
||||
self._page = new_page_tentative
|
||||
self._prior_metadata_hash = None
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="New tab or window.",
|
||||
)
|
||||
)
|
||||
elif name == "input_text":
|
||||
input_field_id = str(args.get("input_field_id"))
|
||||
text_value = str(args.get("text_value"))
|
||||
input_field_name = self._target_name(input_field_id, rects)
|
||||
if input_field_name:
|
||||
action_description = f"I typed '{text_value}' into '{input_field_name}'."
|
||||
else:
|
||||
action_description = f"I input '{text_value}'."
|
||||
await self._playwright_controller.fill_id(self._page, input_field_id, text_value)
|
||||
|
||||
elif name == "scroll_element_up":
|
||||
target_id = str(args.get("target_id"))
|
||||
target_name = self._target_name(target_id, rects)
|
||||
|
||||
if target_name:
|
||||
action_description = f"I scrolled '{target_name}' up."
|
||||
else:
|
||||
action_description = "I scrolled the control up."
|
||||
|
||||
await self._playwright_controller.scroll_id(self._page, target_id, "up")
|
||||
|
||||
elif name == "scroll_element_down":
|
||||
target_id = str(args.get("target_id"))
|
||||
target_name = self._target_name(target_id, rects)
|
||||
|
||||
if target_name:
|
||||
action_description = f"I scrolled '{target_name}' down."
|
||||
else:
|
||||
action_description = "I scrolled the control down."
|
||||
|
||||
await self._playwright_controller.scroll_id(self._page, target_id, "down")
|
||||
|
||||
elif name == "answer_question":
|
||||
question = str(args.get("question"))
|
||||
action_description = f"I answered the following question '{question}' based on the web page."
|
||||
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
|
||||
return await self._summarize_page(question=question, cancellation_token=cancellation_token)
|
||||
elif name == "summarize_page":
|
||||
# Summarize the DOM. No need to take further action. Browser state does not change.
|
||||
action_description = "I summarized the current web page"
|
||||
return await self._summarize_page(cancellation_token=cancellation_token)
|
||||
|
||||
elif name == "hover":
|
||||
target_id = str(args.get("target_id"))
|
||||
target_name = self._target_name(target_id, rects)
|
||||
if target_name:
|
||||
action_description = f"I hovered over '{target_name}'."
|
||||
else:
|
||||
action_description = "I hovered over the control."
|
||||
await self._playwright_controller.hover_id(self._page, target_id)
|
||||
|
||||
elif name == "sleep":
|
||||
action_description = "I am waiting a short period of time before taking further action."
|
||||
await self._playwright_controller.sleep(self._page, 3)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
|
||||
|
||||
await self._page.wait_for_load_state()
|
||||
await self._playwright_controller.sleep(self._page, 3)
|
||||
|
||||
# Handle downloads
|
||||
if self._last_download is not None and self.downloads_folder is not None:
|
||||
fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
|
||||
await self._last_download.save_as(fname) # type: ignore
|
||||
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
|
||||
await self._page.goto(
|
||||
"data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
|
||||
)
|
||||
await self._page.wait_for_load_state()
|
||||
|
||||
# Handle metadata
|
||||
page_metadata = json.dumps(await self._playwright_controller.get_page_metadata(self._page), indent=4)
|
||||
metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest()
|
||||
if metadata_hash != self._prior_metadata_hash:
|
||||
page_metadata = (
|
||||
"\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
|
||||
)
|
||||
else:
|
||||
page_metadata = ""
|
||||
self._prior_metadata_hash = metadata_hash
|
||||
|
||||
new_screenshot = await self._page.screenshot()
|
||||
if self.to_save_screenshots:
|
||||
current_timestamp = "_" + int(time.time()).__str__()
|
||||
screenshot_png_name = "screenshot" + current_timestamp + ".png"
|
||||
|
||||
async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file: # type: ignore
|
||||
await file.write(new_screenshot) # type: ignore
|
||||
self.logger.info(
|
||||
WebSurferEvent(
|
||||
source=self.name,
|
||||
url=self._page.url,
|
||||
message="Screenshot: " + screenshot_png_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Return the complete observation
|
||||
state_description = "The " + await self._get_state_description()
|
||||
message_content = (
|
||||
f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page."
|
||||
)
|
||||
|
||||
return [
|
||||
re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines
|
||||
AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))),
|
||||
]
|
||||
|
||||
async def _get_state_description(self) -> str:
|
||||
assert self._playwright_controller is not None
|
||||
assert self._page is not None
|
||||
|
||||
# Describe the viewport of the new page in words
|
||||
viewport = await self._playwright_controller.get_visual_viewport(self._page)
|
||||
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
|
||||
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
|
||||
if percent_scrolled < 1: # Allow some rounding error
|
||||
position_text = "at the top of the page"
|
||||
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
|
||||
position_text = "at the bottom of the page"
|
||||
else:
|
||||
position_text = str(percent_scrolled) + "% down from the top of the page"
|
||||
|
||||
visible_text = await self._playwright_controller.get_visible_text(self._page)
|
||||
|
||||
# Return the complete observation
|
||||
page_title = await self._page.title()
|
||||
message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n"
|
||||
message_content += f"The following text is visible in the viewport:\n\n{visible_text}"
|
||||
return message_content
|
||||
|
||||
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
|
||||
try:
|
||||
return rects[target]["aria_name"].strip()
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion]) -> List[str]:
|
||||
"""
|
||||
Format the list of targets in the webpage as a string to be used in the agent's prompt.
|
||||
"""
|
||||
targets: List[str] = []
|
||||
for r in list(set(ids)):
|
||||
if r in rects:
|
||||
# Get the role
|
||||
aria_role = rects[r].get("role", "").strip()
|
||||
if len(aria_role) == 0:
|
||||
aria_role = rects[r].get("tag_name", "").strip()
|
||||
|
||||
# Get the name
|
||||
aria_name = re.sub(r"[\n\r]+", " ", rects[r].get("aria_name", "")).strip()
|
||||
|
||||
# What are the actions?
|
||||
actions = ['"click", "hover"']
|
||||
if rects[r]["role"] in ["textbox", "searchbox", "search"]:
|
||||
actions = ['"input_text"']
|
||||
actions_str = "[" + ",".join(actions) + "]"
|
||||
|
||||
targets.append(f'{{"id": {r}, "name": "{aria_name}", "role": "{aria_role}", "tools": {actions_str} }}')
|
||||
|
||||
return targets
|
||||
|
||||
async def _summarize_page(
|
||||
self,
|
||||
question: str | None = None,
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> str:
|
||||
assert self._page is not None
|
||||
|
||||
page_markdown: str = await self._playwright_controller.get_page_markdown(self._page)
|
||||
|
||||
title: str = self._page.url
|
||||
try:
|
||||
title = await self._page.title()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Take a screenshot and scale it
|
||||
screenshot = Image.open(io.BytesIO(await self._page.screenshot()))
|
||||
scaled_screenshot = screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
|
||||
screenshot.close()
|
||||
ag_image = AGImage.from_pil(scaled_screenshot)
|
||||
|
||||
# Prepare the system prompt
|
||||
messages: List[LLMMessage] = []
|
||||
messages.append(SystemMessage(content=WEB_SURFER_QA_SYSTEM_MESSAGE))
|
||||
prompt = WEB_SURFER_QA_PROMPT(title, question)
|
||||
# Grow the buffer (which is added to the prompt) until we overflow the context window or run out of lines
|
||||
buffer = ""
|
||||
# for line in re.split(r"([\r\n]+)", page_markdown):
|
||||
for line in page_markdown.splitlines():
|
||||
trial_message = UserMessage(
|
||||
content=prompt + buffer + line,
|
||||
source=self.name,
|
||||
)
|
||||
|
||||
try:
|
||||
remaining = self._model_client.remaining_tokens(messages + [trial_message])
|
||||
except KeyError:
|
||||
# Use the default if the model isn't found
|
||||
remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message])
|
||||
|
||||
if self._model_client.model_info["vision"] and remaining <= 0:
|
||||
break
|
||||
|
||||
if self._model_client.model_info["vision"] and remaining <= self.SCREENSHOT_TOKENS:
|
||||
break
|
||||
|
||||
buffer += line
|
||||
|
||||
# Nothing to do
|
||||
buffer = buffer.strip()
|
||||
if len(buffer) == 0:
|
||||
return "Nothing to summarize."
|
||||
|
||||
# Append the message
|
||||
if self._model_client.model_info["vision"]:
|
||||
# Multimodal
|
||||
messages.append(
|
||||
UserMessage(
|
||||
content=[
|
||||
prompt + buffer,
|
||||
ag_image,
|
||||
],
|
||||
source=self.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Text only
|
||||
messages.append(
|
||||
UserMessage(
|
||||
content=prompt + buffer,
|
||||
source=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
# Generate the response
|
||||
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
|
||||
self.model_usage.append(response.usage)
|
||||
scaled_screenshot.close()
|
||||
assert isinstance(response.content, str)
|
||||
return response.content
|
||||
|
||||
def _to_config(self) -> MultimodalWebSurferConfig:
|
||||
return MultimodalWebSurferConfig(
|
||||
name=self.name,
|
||||
model_client=self._model_client.dump_component(),
|
||||
downloads_folder=self.downloads_folder,
|
||||
description=self.description,
|
||||
debug_dir=self.debug_dir,
|
||||
headless=self.headless,
|
||||
start_page=self.start_page,
|
||||
animate_actions=self.animate_actions,
|
||||
to_save_screenshots=self.to_save_screenshots,
|
||||
use_ocr=self.use_ocr,
|
||||
browser_channel=self.browser_channel,
|
||||
browser_data_dir=self.browser_data_dir,
|
||||
to_resize_viewport=self.to_resize_viewport,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: MultimodalWebSurferConfig) -> Self:
|
||||
return cls(
|
||||
name=config.name,
|
||||
model_client=ChatCompletionClient.load_component(config.model_client),
|
||||
downloads_folder=config.downloads_folder,
|
||||
description=config.description or cls.DEFAULT_DESCRIPTION,
|
||||
debug_dir=config.debug_dir,
|
||||
headless=config.headless,
|
||||
start_page=config.start_page or cls.DEFAULT_START_PAGE,
|
||||
animate_actions=config.animate_actions,
|
||||
to_save_screenshots=config.to_save_screenshots,
|
||||
use_ocr=config.use_ocr,
|
||||
browser_channel=config.browser_channel,
|
||||
browser_data_dir=config.browser_data_dir,
|
||||
to_resize_viewport=config.to_resize_viewport,
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
WEB_SURFER_TOOL_PROMPT_MM = """
|
||||
{state_description}
|
||||
|
||||
Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
|
||||
|
||||
{visible_targets}{other_targets_str}{focused_hint}
|
||||
|
||||
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
|
||||
|
||||
{tool_names}
|
||||
|
||||
When deciding between tools, consider if the request can be best addressed by:
|
||||
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
|
||||
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
|
||||
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
|
||||
|
||||
My request follows:
|
||||
"""
|
||||
|
||||
WEB_SURFER_TOOL_PROMPT_TEXT = """
|
||||
{state_description}
|
||||
|
||||
You have also identified the following interactive components:
|
||||
|
||||
{visible_targets}{other_targets_str}{focused_hint}
|
||||
|
||||
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
|
||||
|
||||
{tool_names}
|
||||
|
||||
When deciding between tools, consider if the request can be best addressed by:
|
||||
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
|
||||
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
|
||||
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
|
||||
|
||||
My request follows:
|
||||
"""
|
||||
|
||||
|
||||
WEB_SURFER_QA_SYSTEM_MESSAGE = """
|
||||
You are a helpful assistant that can summarize long documents to answer question.
|
||||
"""
|
||||
|
||||
|
||||
def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str:
|
||||
base_prompt = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport."
|
||||
if question is not None:
|
||||
return (
|
||||
f"{base_prompt} Please summarize the webpage into one or two paragraphs with respect to '{question}':\n\n"
|
||||
)
|
||||
else:
|
||||
return f"{base_prompt} Please summarize the webpage into one or two paragraphs:\n\n"
|
||||
@@ -0,0 +1,96 @@
|
||||
import io
|
||||
import random
|
||||
from typing import BinaryIO, Dict, List, Tuple, cast
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from ._types import DOMRectangle, InteractiveRegion
|
||||
|
||||
TOP_NO_LABEL_ZONE = 20 # Don't print any labels close the top of the page
|
||||
|
||||
|
||||
def add_set_of_mark(
|
||||
screenshot: bytes | Image.Image | io.BufferedIOBase, ROIs: Dict[str, InteractiveRegion]
|
||||
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
|
||||
if isinstance(screenshot, Image.Image):
|
||||
return _add_set_of_mark(screenshot, ROIs)
|
||||
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot = io.BytesIO(screenshot)
|
||||
|
||||
# TODO: Not sure why this cast was needed, but by this point screenshot is a binary file-like object
|
||||
image = Image.open(cast(BinaryIO, screenshot))
|
||||
comp, visible_rects, rects_above, rects_below = _add_set_of_mark(image, ROIs)
|
||||
image.close()
|
||||
return comp, visible_rects, rects_above, rects_below
|
||||
|
||||
|
||||
def _add_set_of_mark(
|
||||
screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion]
|
||||
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
|
||||
visible_rects: List[str] = list()
|
||||
rects_above: List[str] = list() # Scroll up to see
|
||||
rects_below: List[str] = list() # Scroll down to see
|
||||
|
||||
fnt = ImageFont.load_default(14)
|
||||
base = screenshot.convert("L").convert("RGBA")
|
||||
overlay = Image.new("RGBA", base.size)
|
||||
|
||||
draw = ImageDraw.Draw(overlay)
|
||||
for r in ROIs:
|
||||
for rect in ROIs[r]["rects"]:
|
||||
# Empty rectangles
|
||||
if not rect:
|
||||
continue
|
||||
if rect["width"] * rect["height"] == 0:
|
||||
continue
|
||||
|
||||
mid = ((rect["right"] + rect["left"]) / 2.0, (rect["top"] + rect["bottom"]) / 2.0)
|
||||
|
||||
if 0 <= mid[0] and mid[0] < base.size[0]:
|
||||
if mid[1] < 0:
|
||||
rects_above.append(r)
|
||||
elif mid[1] >= base.size[1]:
|
||||
rects_below.append(r)
|
||||
else:
|
||||
visible_rects.append(r)
|
||||
_draw_roi(draw, int(r), fnt, rect)
|
||||
|
||||
comp = Image.alpha_composite(base, overlay)
|
||||
overlay.close()
|
||||
return comp, visible_rects, rects_above, rects_below
|
||||
|
||||
|
||||
def _draw_roi(
|
||||
draw: ImageDraw.ImageDraw, idx: int, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, rect: DOMRectangle
|
||||
) -> None:
|
||||
color = _color(idx)
|
||||
luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11
|
||||
text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)
|
||||
|
||||
roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))
|
||||
|
||||
label_location = (rect["right"], rect["top"])
|
||||
label_anchor = "rb"
|
||||
|
||||
if label_location[1] <= TOP_NO_LABEL_ZONE:
|
||||
label_location = (rect["right"], rect["bottom"])
|
||||
label_anchor = "rt"
|
||||
|
||||
draw.rectangle(roi, outline=color, fill=(color[0], color[1], color[2], 48), width=2)
|
||||
|
||||
# TODO: Having trouble with these types being partially Unknown.
|
||||
bbox = draw.textbbox(label_location, str(idx), font=font, anchor=label_anchor, align="center") # type: ignore
|
||||
bbox = (bbox[0] - 3, bbox[1] - 3, bbox[2] + 3, bbox[3] + 3)
|
||||
draw.rectangle(bbox, fill=color)
|
||||
|
||||
# TODO: Having trouble with these types being partially Unknown.
|
||||
draw.text(label_location, str(idx), fill=text_color, font=font, anchor=label_anchor, align="center") # type: ignore
|
||||
|
||||
|
||||
def _color(identifier: int) -> Tuple[int, int, int, int]:
|
||||
rnd = random.Random(int(identifier))
|
||||
color = [rnd.randint(0, 255), rnd.randint(125, 255), rnd.randint(0, 50)]
|
||||
rnd.shuffle(color)
|
||||
color.append(255)
|
||||
return cast(Tuple[int, int, int, int], tuple(color))
|
||||
@@ -0,0 +1,317 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from agentdhal_core.tools._base import ParametersSchema, ToolSchema
|
||||
|
||||
|
||||
def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name=tooldef["function"]["name"],
|
||||
description=tooldef["function"]["description"],
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties=tooldef["function"]["parameters"]["properties"],
|
||||
required=tooldef["function"]["parameters"]["required"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
REASONING_TOOL_PROMPT = (
|
||||
"A short description of the action to be performed and reason for doing so, do not mention the user."
|
||||
)
|
||||
|
||||
TOOL_VISIT_URL: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "visit_url",
|
||||
"description": "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to visit in the browser.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "url"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_WEB_SEARCH: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Performs a web search on Bing.com with the given query.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The web search query to use.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_HISTORY_BACK: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "history_back",
|
||||
"description": "Navigates back one page in the browser's history. This is equivalent to clicking the browser back button.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
},
|
||||
"required": ["reasoning"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SCROLL_UP: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "scroll_up",
|
||||
"description": "Scrolls the entire browser viewport one page UP towards the beginning.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
},
|
||||
"required": ["reasoning"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SCROLL_DOWN: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "scroll_down",
|
||||
"description": "Scrolls the entire browser viewport one page DOWN towards the end.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
},
|
||||
"required": ["reasoning"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_CLICK: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "click",
|
||||
"description": "Clicks the mouse on the target with the given id.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"target_id": {
|
||||
"type": "integer",
|
||||
"description": "The numeric id of the target to click.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "target_id"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_TYPE: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "input_text",
|
||||
"description": "Types the given text value into the specified field.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"input_field_id": {
|
||||
"type": "integer",
|
||||
"description": "The numeric id of the input field to receive the text.",
|
||||
},
|
||||
"text_value": {
|
||||
"type": "string",
|
||||
"description": "The text to type into the input field.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "input_field_id", "text_value"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SCROLL_ELEMENT_DOWN: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "scroll_element_down",
|
||||
"description": "Scrolls a given html element (e.g., a div or a menu) DOWN.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"target_id": {
|
||||
"type": "integer",
|
||||
"description": "The numeric id of the target to scroll down.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "target_id"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SCROLL_ELEMENT_UP: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "scroll_element_up",
|
||||
"description": "Scrolls a given html element (e.g., a div or a menu) UP.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"target_id": {
|
||||
"type": "integer",
|
||||
"description": "The numeric id of the target to scroll UP.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "target_id"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_HOVER: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "hover",
|
||||
"description": "Hovers the mouse over the target with the given id.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"target_id": {
|
||||
"type": "integer",
|
||||
"description": "The numeric id of the target to hover over.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "target_id"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
TOOL_READ_PAGE_AND_ANSWER: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "answer_question",
|
||||
"description": "Uses AI to answer a question about the current webpage's content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to answer.",
|
||||
},
|
||||
},
|
||||
"required": ["reasoning", "question"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SUMMARIZE_PAGE: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "summarize_page",
|
||||
"description": "Uses AI to summarize the entire page.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
},
|
||||
"required": ["reasoning"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
TOOL_SLEEP: ToolSchema = _load_tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sleep",
|
||||
"description": "Wait a short period of time. Call this function if the page has not yet fully loaded, or if it is determined that a small delay would increase the task's chances of success.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": REASONING_TOOL_PROMPT,
|
||||
},
|
||||
},
|
||||
"required": ["reasoning"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
106
agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py
Normal file
106
agent_dhal/agentdhal_extensions/agents/web_surfer/_types.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import Any, Dict, List, TypedDict, Union
|
||||
|
||||
from agentdhal_core import FunctionCall, Image
|
||||
from agentdhal_core.models import FunctionExecutionResult
|
||||
|
||||
UserContent = Union[str, List[Union[str, Image]]]
|
||||
AssistantContent = Union[str, List[FunctionCall]]
|
||||
FunctionExecutionContent = List[FunctionExecutionResult]
|
||||
SystemContent = str
|
||||
|
||||
|
||||
class DOMRectangle(TypedDict):
|
||||
x: Union[int, float]
|
||||
y: Union[int, float]
|
||||
width: Union[int, float]
|
||||
height: Union[int, float]
|
||||
top: Union[int, float]
|
||||
right: Union[int, float]
|
||||
bottom: Union[int, float]
|
||||
left: Union[int, float]
|
||||
|
||||
|
||||
class VisualViewport(TypedDict):
|
||||
height: Union[int, float]
|
||||
width: Union[int, float]
|
||||
offsetLeft: Union[int, float]
|
||||
offsetTop: Union[int, float]
|
||||
pageLeft: Union[int, float]
|
||||
pageTop: Union[int, float]
|
||||
scale: Union[int, float]
|
||||
clientWidth: Union[int, float]
|
||||
clientHeight: Union[int, float]
|
||||
scrollWidth: Union[int, float]
|
||||
scrollHeight: Union[int, float]
|
||||
|
||||
|
||||
class InteractiveRegion(TypedDict):
|
||||
tag_name: str
|
||||
role: str
|
||||
aria_name: str
|
||||
v_scrollable: bool
|
||||
rects: List[DOMRectangle]
|
||||
|
||||
|
||||
# Helper functions for dealing with JSON. Not sure there's a better way?
|
||||
|
||||
|
||||
def _get_str(d: Any, k: str) -> str:
|
||||
val = d[k]
|
||||
assert isinstance(val, str)
|
||||
return val
|
||||
|
||||
|
||||
def _get_number(d: Any, k: str) -> Union[int, float]:
|
||||
val = d[k]
|
||||
assert isinstance(val, int) or isinstance(val, float)
|
||||
return val
|
||||
|
||||
|
||||
def _get_bool(d: Any, k: str) -> bool:
|
||||
val = d[k]
|
||||
assert isinstance(val, bool)
|
||||
return val
|
||||
|
||||
|
||||
def domrectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle:
|
||||
return DOMRectangle(
|
||||
x=_get_number(rect, "x"),
|
||||
y=_get_number(rect, "y"),
|
||||
width=_get_number(rect, "width"),
|
||||
height=_get_number(rect, "height"),
|
||||
top=_get_number(rect, "top"),
|
||||
right=_get_number(rect, "right"),
|
||||
bottom=_get_number(rect, "bottom"),
|
||||
left=_get_number(rect, "left"),
|
||||
)
|
||||
|
||||
|
||||
def interactiveregion_from_dict(region: Dict[str, Any]) -> InteractiveRegion:
|
||||
typed_rects: List[DOMRectangle] = []
|
||||
for rect in region["rects"]:
|
||||
typed_rects.append(domrectangle_from_dict(rect))
|
||||
|
||||
return InteractiveRegion(
|
||||
tag_name=_get_str(region, "tag_name"),
|
||||
role=_get_str(region, "role"),
|
||||
aria_name=_get_str(region, "aria-name"),
|
||||
v_scrollable=_get_bool(region, "v-scrollable"),
|
||||
rects=typed_rects,
|
||||
)
|
||||
|
||||
|
||||
def visualviewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport:
|
||||
return VisualViewport(
|
||||
height=_get_number(viewport, "height"),
|
||||
width=_get_number(viewport, "width"),
|
||||
offsetLeft=_get_number(viewport, "offsetLeft"),
|
||||
offsetTop=_get_number(viewport, "offsetTop"),
|
||||
pageLeft=_get_number(viewport, "pageLeft"),
|
||||
pageTop=_get_number(viewport, "pageTop"),
|
||||
scale=_get_number(viewport, "scale"),
|
||||
clientWidth=_get_number(viewport, "clientWidth"),
|
||||
clientHeight=_get_number(viewport, "clientHeight"),
|
||||
scrollWidth=_get_number(viewport, "scrollWidth"),
|
||||
scrollHeight=_get_number(viewport, "scrollHeight"),
|
||||
)
|
||||
429
agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js
Normal file
429
agent_dhal/agentdhal_extensions/agents/web_surfer/page_script.js
Normal file
@@ -0,0 +1,429 @@
|
||||
var MultimodalWebSurfer = MultimodalWebSurfer || (function() {
|
||||
let nextLabel = 10;
|
||||
|
||||
let roleMapping = {
|
||||
"a": "link",
|
||||
"area": "link",
|
||||
"button": "button",
|
||||
"input, type=button": "button",
|
||||
"input, type=checkbox": "checkbox",
|
||||
"input, type=email": "textbox",
|
||||
"input, type=number": "spinbutton",
|
||||
"input, type=radio": "radio",
|
||||
"input, type=range": "slider",
|
||||
"input, type=reset": "button",
|
||||
"input, type=search": "searchbox",
|
||||
"input, type=submit": "button",
|
||||
"input, type=tel": "textbox",
|
||||
"input, type=text": "textbox",
|
||||
"input, type=url": "textbox",
|
||||
"search": "search",
|
||||
"select": "combobox",
|
||||
"option": "option",
|
||||
"textarea": "textbox"
|
||||
};
|
||||
|
||||
let getCursor = function(elm) {
|
||||
return window.getComputedStyle(elm)["cursor"];
|
||||
};
|
||||
|
||||
let getInteractiveElements = function() {
|
||||
|
||||
let results = []
|
||||
let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"];
|
||||
let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"];
|
||||
|
||||
// Get the main interactive elements
|
||||
let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])");
|
||||
for (let i=0; i<nodeList.length; i++) { // Copy to something mutable
|
||||
results.push(nodeList[i]);
|
||||
}
|
||||
|
||||
// Anything not already included that has a suitable role
|
||||
nodeList = document.querySelectorAll("[role]");
|
||||
for (let i=0; i<nodeList.length; i++) { // Copy to something mutable
|
||||
if (results.indexOf(nodeList[i]) == -1) {
|
||||
let role = nodeList[i].getAttribute("role");
|
||||
if (roles.indexOf(role) > -1) {
|
||||
results.push(nodeList[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Any element that changes the cursor to something implying interactivity
|
||||
nodeList = document.querySelectorAll("*");
|
||||
for (let i=0; i<nodeList.length; i++) {
|
||||
let node = nodeList[i];
|
||||
|
||||
// Cursor is default, or does not suggest interactivity
|
||||
let cursor = getCursor(node);
|
||||
if (inertCursors.indexOf(cursor) >= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Move up to the first instance of this cursor change
|
||||
parent = node.parentNode;
|
||||
while (parent && getCursor(parent) == cursor) {
|
||||
node = parent;
|
||||
parent = node.parentNode;
|
||||
}
|
||||
|
||||
// Add the node if it is new
|
||||
if (results.indexOf(node) == -1) {
|
||||
results.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
};
|
||||
|
||||
let labelElements = function(elements) {
|
||||
for (let i=0; i<elements.length; i++) {
|
||||
if (!elements[i].hasAttribute("__elementId")) {
|
||||
elements[i].setAttribute("__elementId", "" + (nextLabel++));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let isTopmost = function(element, x, y) {
|
||||
let hit = document.elementFromPoint(x, y);
|
||||
|
||||
// Hack to handle elements outside the viewport
|
||||
if (hit === null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
while (hit) {
|
||||
if (hit == element) return true;
|
||||
hit = hit.parentNode;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
let getFocusedElementId = function() {
|
||||
let elm = document.activeElement;
|
||||
while (elm) {
|
||||
if (elm.hasAttribute && elm.hasAttribute("__elementId")) {
|
||||
return elm.getAttribute("__elementId");
|
||||
}
|
||||
elm = elm.parentNode;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
let trimmedInnerText = function(element) {
|
||||
if (!element) {
|
||||
return "";
|
||||
}
|
||||
let text = element.innerText;
|
||||
if (!text) {
|
||||
return "";
|
||||
}
|
||||
return text.trim();
|
||||
};
|
||||
|
||||
let getApproximateAriaName = function(element) {
|
||||
// Check for aria labels
|
||||
if (element.hasAttribute("aria-labelledby")) {
|
||||
let buffer = "";
|
||||
let ids = element.getAttribute("aria-labelledby").split(" ");
|
||||
for (let i=0; i<ids.length; i++) {
|
||||
let label = document.getElementById(ids[i]);
|
||||
if (label) {
|
||||
buffer = buffer + " " + trimmedInnerText(label);
|
||||
}
|
||||
}
|
||||
return buffer.trim();
|
||||
}
|
||||
|
||||
if (element.hasAttribute("aria-label")) {
|
||||
return element.getAttribute("aria-label");
|
||||
}
|
||||
|
||||
// Check for labels
|
||||
if (element.hasAttribute("id")) {
|
||||
let label_id = element.getAttribute("id");
|
||||
let label = "";
|
||||
let labels = document.querySelectorAll("label[for='" + label_id + "']");
|
||||
for (let j=0; j<labels.length; j++) {
|
||||
label += labels[j].innerText + " ";
|
||||
}
|
||||
label = label.trim();
|
||||
if (label != "") {
|
||||
return label;
|
||||
}
|
||||
}
|
||||
|
||||
if (element.parentElement && element.parentElement.tagName == "LABEL") {
|
||||
return element.parentElement.innerText;
|
||||
}
|
||||
|
||||
// Check for alt text or titles
|
||||
if (element.hasAttribute("alt")) {
|
||||
return element.getAttribute("alt")
|
||||
}
|
||||
|
||||
if (element.hasAttribute("title")) {
|
||||
return element.getAttribute("title")
|
||||
}
|
||||
|
||||
return trimmedInnerText(element);
|
||||
};
|
||||
|
||||
let getApproximateAriaRole = function(element) {
|
||||
let tag = element.tagName.toLowerCase();
|
||||
if (tag == "input" && element.hasAttribute("type")) {
|
||||
tag = tag + ", type=" + element.getAttribute("type");
|
||||
}
|
||||
|
||||
if (element.hasAttribute("role")) {
|
||||
return [element.getAttribute("role"), tag];
|
||||
}
|
||||
else if (tag in roleMapping) {
|
||||
return [roleMapping[tag], tag];
|
||||
}
|
||||
else {
|
||||
return ["", tag];
|
||||
}
|
||||
};
|
||||
|
||||
let getInteractiveRects = function() {
|
||||
labelElements(getInteractiveElements());
|
||||
let elements = document.querySelectorAll("[__elementId]");
|
||||
let results = {};
|
||||
for (let i=0; i<elements.length; i++) {
|
||||
let key = elements[i].getAttribute("__elementId");
|
||||
let rects = elements[i].getClientRects();
|
||||
let ariaRole = getApproximateAriaRole(elements[i]);
|
||||
let ariaName = getApproximateAriaName(elements[i]);
|
||||
let vScrollable = elements[i].scrollHeight - elements[i].clientHeight >= 1;
|
||||
|
||||
let record = {
|
||||
"tag_name": ariaRole[1],
|
||||
"role": ariaRole[0],
|
||||
"aria-name": ariaName,
|
||||
"v-scrollable": vScrollable,
|
||||
"rects": []
|
||||
};
|
||||
|
||||
for (const rect of rects) {
|
||||
let x = rect.left + rect.width/2;
|
||||
let y = rect.top + rect.height/2;
|
||||
if (isTopmost(elements[i], x, y)) {
|
||||
record["rects"].push(JSON.parse(JSON.stringify(rect)));
|
||||
}
|
||||
}
|
||||
|
||||
if (record["rects"].length > 0) {
|
||||
results[key] = record;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
};
|
||||
|
||||
let getVisualViewport = function() {
|
||||
let vv = window.visualViewport;
|
||||
let de = document.documentElement;
|
||||
return {
|
||||
"height": vv ? vv.height : 0,
|
||||
"width": vv ? vv.width : 0,
|
||||
"offsetLeft": vv ? vv.offsetLeft : 0,
|
||||
"offsetTop": vv ? vv.offsetTop : 0,
|
||||
"pageLeft": vv ? vv.pageLeft : 0,
|
||||
"pageTop": vv ? vv.pageTop : 0,
|
||||
"scale": vv ? vv.scale : 0,
|
||||
"clientWidth": de ? de.clientWidth : 0,
|
||||
"clientHeight": de ? de.clientHeight : 0,
|
||||
"scrollWidth": de ? de.scrollWidth : 0,
|
||||
"scrollHeight": de ? de.scrollHeight : 0
|
||||
};
|
||||
};
|
||||
|
||||
let _getMetaTags = function() {
|
||||
let meta = document.querySelectorAll("meta");
|
||||
let results = {};
|
||||
for (let i = 0; i<meta.length; i++) {
|
||||
let key = null;
|
||||
if (meta[i].hasAttribute("name")) {
|
||||
key = meta[i].getAttribute("name");
|
||||
}
|
||||
else if (meta[i].hasAttribute("property")) {
|
||||
key = meta[i].getAttribute("property");
|
||||
}
|
||||
else {
|
||||
continue;
|
||||
}
|
||||
if (meta[i].hasAttribute("content")) {
|
||||
results[key] = meta[i].getAttribute("content");
|
||||
}
|
||||
}
|
||||
return results;
|
||||
};
|
||||
|
||||
let _getJsonLd = function() {
|
||||
let jsonld = [];
|
||||
let scripts = document.querySelectorAll('script[type="application/ld+json"]');
|
||||
for (let i=0; i<scripts.length; i++) {
|
||||
jsonld.push(scripts[i].innerHTML.trim());
|
||||
}
|
||||
return jsonld;
|
||||
};
|
||||
|
||||
// From: https://www.stevefenton.co.uk/blog/2022/12/parse-microdata-with-javascript/
|
||||
let _getMicrodata = function() {
|
||||
function sanitize(input) {
|
||||
return input.replace(/\s/gi, ' ').trim();
|
||||
}
|
||||
|
||||
function addValue(information, name, value) {
|
||||
if (information[name]) {
|
||||
if (typeof information[name] === 'array') {
|
||||
information[name].push(value);
|
||||
} else {
|
||||
const arr = [];
|
||||
arr.push(information[name]);
|
||||
arr.push(value);
|
||||
information[name] = arr;
|
||||
}
|
||||
} else {
|
||||
information[name] = value;
|
||||
}
|
||||
}
|
||||
|
||||
function traverseItem(item, information) {
|
||||
const children = item.children;
|
||||
|
||||
for (let i = 0; i < children.length; i++) {
|
||||
const child = children[i];
|
||||
|
||||
if (child.hasAttribute('itemscope')) {
|
||||
if (child.hasAttribute('itemprop')) {
|
||||
const itemProp = child.getAttribute('itemprop');
|
||||
const itemType = child.getAttribute('itemtype');
|
||||
|
||||
const childInfo = {
|
||||
itemType: itemType
|
||||
};
|
||||
|
||||
traverseItem(child, childInfo);
|
||||
|
||||
itemProp.split(' ').forEach(propName => {
|
||||
addValue(information, propName, childInfo);
|
||||
});
|
||||
}
|
||||
|
||||
} else if (child.hasAttribute('itemprop')) {
|
||||
const itemProp = child.getAttribute('itemprop');
|
||||
itemProp.split(' ').forEach(propName => {
|
||||
if (propName === 'url') {
|
||||
addValue(information, propName, child.href);
|
||||
} else {
|
||||
addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || ""));
|
||||
}
|
||||
});
|
||||
traverseItem(child, information);
|
||||
} else {
|
||||
traverseItem(child, information);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const microdata = [];
|
||||
|
||||
document.querySelectorAll("[itemscope]").forEach(function(elem, i) {
|
||||
const itemType = elem.getAttribute('itemtype');
|
||||
const information = {
|
||||
itemType: itemType
|
||||
};
|
||||
traverseItem(elem, information);
|
||||
microdata.push(information);
|
||||
});
|
||||
|
||||
return microdata;
|
||||
};
|
||||
|
||||
let getPageMetadata = function() {
|
||||
let jsonld = _getJsonLd();
|
||||
let metaTags = _getMetaTags();
|
||||
let microdata = _getMicrodata();
|
||||
let results = {}
|
||||
if (jsonld.length > 0) {
|
||||
try {
|
||||
results["jsonld"] = JSON.parse(jsonld);
|
||||
}
|
||||
catch (e) {
|
||||
results["jsonld"] = jsonld;
|
||||
}
|
||||
}
|
||||
if (microdata.length > 0) {
|
||||
results["microdata"] = microdata;
|
||||
}
|
||||
for (let key in metaTags) {
|
||||
if (metaTags.hasOwnProperty(key)) {
|
||||
results["meta_tags"] = metaTags;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
};
|
||||
|
||||
|
||||
let getVisibleText = function() {
|
||||
// Get the window’s current viewport boundaries
|
||||
const viewportHeight = window.innerHeight || document.documentElement.clientHeight;
|
||||
const viewportWidth = window.innerWidth || document.documentElement.clientWidth;
|
||||
|
||||
let textInView = "";
|
||||
const walker = document.createTreeWalker(
|
||||
document.body,
|
||||
NodeFilter.SHOW_TEXT,
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
while (walker.nextNode()) {
|
||||
const textNode = walker.currentNode;
|
||||
// Create a range to retrieve bounding rectangles of the current text node
|
||||
const range = document.createRange();
|
||||
range.selectNodeContents(textNode);
|
||||
|
||||
const rects = range.getClientRects();
|
||||
|
||||
// Check if any rect is inside (or partially inside) the viewport
|
||||
for (const rect of rects) {
|
||||
const isVisible =
|
||||
rect.width > 0 &&
|
||||
rect.height > 0 &&
|
||||
rect.bottom >= 0 &&
|
||||
rect.right >= 0 &&
|
||||
rect.top <= viewportHeight &&
|
||||
rect.left <= viewportWidth;
|
||||
|
||||
if (isVisible) {
|
||||
textInView += textNode.nodeValue.replace(/\s+/g, " ");
|
||||
// Is the parent a block element?
|
||||
if (textNode.parentNode) {
|
||||
const parent = textNode.parentNode;
|
||||
const style = window.getComputedStyle(parent);
|
||||
if (["inline", "hidden", "none"].indexOf(style.display) === -1) {
|
||||
textInView += "\n";
|
||||
}
|
||||
}
|
||||
break; // No need to check other rects once found visible
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove blank lines from textInView
|
||||
textInView = textInView.replace(/^\s*\n/gm, "").trim().replace(/\n+/g, "\n");
|
||||
return textInView;
|
||||
};
|
||||
|
||||
return {
|
||||
getInteractiveRects: getInteractiveRects,
|
||||
getVisualViewport: getVisualViewport,
|
||||
getFocusedElementId: getFocusedElementId,
|
||||
getPageMetadata: getPageMetadata,
|
||||
getVisibleText: getVisibleText,
|
||||
};
|
||||
})();
|
||||
@@ -0,0 +1,578 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from playwright._impl._errors import Error as PlaywrightError
|
||||
from playwright._impl._errors import TimeoutError
|
||||
from playwright.async_api import Download, Page
|
||||
|
||||
from ._types import (
|
||||
InteractiveRegion,
|
||||
VisualViewport,
|
||||
interactiveregion_from_dict,
|
||||
visualviewport_from_dict,
|
||||
)
|
||||
|
||||
markitdown: ModuleType | None = None
|
||||
try:
|
||||
# Suppress warnings from markitdown -- which is pretty chatty
|
||||
warnings.filterwarnings(action="ignore", module="markitdown")
|
||||
import markitdown
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class PlaywrightController:
|
||||
"""
|
||||
A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling.
|
||||
|
||||
Args:
|
||||
downloads_folder (str | None): The folder to save downloads to. If None, downloads are not saved.
|
||||
animate_actions (bool): Whether to animate the actions (create fake cursor to click).
|
||||
viewport_width (int): The width of the viewport.
|
||||
viewport_height (int): The height of the viewport.
|
||||
_download_handler (Optional[Callable[[Download], None]]): A function to handle downloads.
|
||||
to_resize_viewport (bool): Whether to resize the viewport
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
downloads_folder: str | None = None,
|
||||
animate_actions: bool = False,
|
||||
viewport_width: int = 1440,
|
||||
viewport_height: int = 900,
|
||||
_download_handler: Optional[Callable[[Download], None]] = None,
|
||||
to_resize_viewport: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the PlaywrightController.
|
||||
"""
|
||||
assert isinstance(animate_actions, bool)
|
||||
assert isinstance(viewport_width, int)
|
||||
assert isinstance(viewport_height, int)
|
||||
assert viewport_height > 0
|
||||
assert viewport_width > 0
|
||||
|
||||
self.animate_actions = animate_actions
|
||||
self.downloads_folder = downloads_folder
|
||||
self.viewport_width = viewport_width
|
||||
self.viewport_height = viewport_height
|
||||
self._download_handler = _download_handler
|
||||
self.to_resize_viewport = to_resize_viewport
|
||||
self._page_script: str = ""
|
||||
self.last_cursor_position: Tuple[float, float] = (0.0, 0.0)
|
||||
self._markdown_converter: Optional[Any] | None = None
|
||||
|
||||
# Read page_script
|
||||
with open(
|
||||
os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt", encoding="utf-8"
|
||||
) as fh:
|
||||
self._page_script = fh.read()
|
||||
|
||||
async def sleep(self, page: Page, duration: Union[int, float]) -> None:
|
||||
"""
|
||||
Pause the execution for a specified duration.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
duration (Union[int, float]): The duration to sleep in milliseconds.
|
||||
"""
|
||||
assert page is not None
|
||||
await page.wait_for_timeout(duration * 1000)
|
||||
|
||||
async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion]:
|
||||
"""
|
||||
Retrieve interactive regions from the web page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
Dict[str, InteractiveRegion]: A dictionary of interactive regions.
|
||||
"""
|
||||
assert page is not None
|
||||
# Read the regions from the DOM
|
||||
try:
|
||||
await page.evaluate(self._page_script)
|
||||
except Exception:
|
||||
pass
|
||||
result = cast(Dict[str, Dict[str, Any]], await page.evaluate("MultimodalWebSurfer.getInteractiveRects();"))
|
||||
|
||||
# Convert the results into appropriate types
|
||||
assert isinstance(result, dict)
|
||||
typed_results: Dict[str, InteractiveRegion] = {}
|
||||
for k in result:
|
||||
assert isinstance(k, str)
|
||||
typed_results[k] = interactiveregion_from_dict(result[k])
|
||||
|
||||
return typed_results
|
||||
|
||||
async def get_visual_viewport(self, page: Page) -> VisualViewport:
|
||||
"""
|
||||
Retrieve the visual viewport of the web page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
VisualViewport: The visual viewport of the page.
|
||||
"""
|
||||
assert page is not None
|
||||
try:
|
||||
await page.evaluate(self._page_script)
|
||||
except Exception:
|
||||
pass
|
||||
return visualviewport_from_dict(await page.evaluate("MultimodalWebSurfer.getVisualViewport();"))
|
||||
|
||||
async def get_focused_rect_id(self, page: Page) -> str | None:
|
||||
"""
|
||||
Retrieve the ID of the currently focused element.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
str: The ID of the focused element or None if no control has focus.
|
||||
"""
|
||||
assert page is not None
|
||||
try:
|
||||
await page.evaluate(self._page_script)
|
||||
except Exception:
|
||||
pass
|
||||
result = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();")
|
||||
return None if result is None else str(result)
|
||||
|
||||
async def get_page_metadata(self, page: Page) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve metadata from the web page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary of page metadata.
|
||||
"""
|
||||
assert page is not None
|
||||
try:
|
||||
await page.evaluate(self._page_script)
|
||||
except Exception:
|
||||
pass
|
||||
result = await page.evaluate("MultimodalWebSurfer.getPageMetadata();")
|
||||
assert isinstance(result, dict)
|
||||
return cast(Dict[str, Any], result)
|
||||
|
||||
async def on_new_page(self, page: Page) -> None:
|
||||
"""
|
||||
Handle actions to perform on a new page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
"""
|
||||
assert page is not None
|
||||
page.on("download", self._download_handler) # type: ignore
|
||||
if self.to_resize_viewport and self.viewport_width and self.viewport_height:
|
||||
await page.set_viewport_size({"width": self.viewport_width, "height": self.viewport_height})
|
||||
await self.sleep(page, 0.2)
|
||||
await page.add_init_script(path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"))
|
||||
await page.wait_for_load_state()
|
||||
|
||||
async def back(self, page: Page) -> None:
|
||||
"""
|
||||
Navigate back to the previous page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
"""
|
||||
assert page is not None
|
||||
await page.go_back()
|
||||
|
||||
async def visit_page(self, page: Page, url: str) -> Tuple[bool, bool]:
|
||||
"""
|
||||
Visit a specified URL.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
url (str): The URL to visit.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool]: A tuple indicating whether to reset prior metadata hash and last download.
|
||||
"""
|
||||
assert page is not None
|
||||
reset_prior_metadata_hash = False
|
||||
reset_last_download = False
|
||||
try:
|
||||
# Regular webpage
|
||||
await page.goto(url)
|
||||
await page.wait_for_load_state()
|
||||
reset_prior_metadata_hash = True
|
||||
except Exception as e_outer:
|
||||
# Downloaded file
|
||||
if self.downloads_folder and "net::ERR_ABORTED" in str(e_outer):
|
||||
async with page.expect_download() as download_info:
|
||||
try:
|
||||
await page.goto(url)
|
||||
except Exception as e_inner:
|
||||
if "net::ERR_ABORTED" in str(e_inner):
|
||||
pass
|
||||
else:
|
||||
raise e_inner
|
||||
download = await download_info.value
|
||||
fname = os.path.join(self.downloads_folder, download.suggested_filename)
|
||||
await download.save_as(fname)
|
||||
message = f"<body style=\"margin: 20px;\"><h1>Successfully downloaded '{download.suggested_filename}' to local path:<br><br>{fname}</h1></body>"
|
||||
await page.goto(
|
||||
"data:text/html;base64," + base64.b64encode(message.encode("utf-8")).decode("utf-8")
|
||||
)
|
||||
reset_last_download = True
|
||||
else:
|
||||
raise e_outer
|
||||
return reset_prior_metadata_hash, reset_last_download
|
||||
|
||||
async def page_down(self, page: Page) -> None:
|
||||
"""
|
||||
Scroll the page down by one viewport height minus 50 pixels.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
"""
|
||||
assert page is not None
|
||||
await page.evaluate(f"window.scrollBy(0, {self.viewport_height-50});")
|
||||
|
||||
async def page_up(self, page: Page) -> None:
|
||||
"""
|
||||
Scroll the page up by one viewport height minus 50 pixels.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
"""
|
||||
assert page is not None
|
||||
await page.evaluate(f"window.scrollBy(0, -{self.viewport_height-50});")
|
||||
|
||||
async def gradual_cursor_animation(
|
||||
self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float
|
||||
) -> None:
|
||||
"""
|
||||
Animate the cursor movement gradually from start to end coordinates.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
start_x (float): The starting x-coordinate.
|
||||
start_y (float): The starting y-coordinate.
|
||||
end_x (float): The ending x-coordinate.
|
||||
end_y (float): The ending y-coordinate.
|
||||
"""
|
||||
# animation helper
|
||||
steps = 20
|
||||
for step in range(steps):
|
||||
x = start_x + (end_x - start_x) * (step / steps)
|
||||
y = start_y + (end_y - start_y) * (step / steps)
|
||||
# await page.mouse.move(x, y, steps=1)
|
||||
await page.evaluate(f"""
|
||||
(function() {{
|
||||
let cursor = document.getElementById('red-cursor');
|
||||
cursor.style.left = '{x}px';
|
||||
cursor.style.top = '{y}px';
|
||||
}})();
|
||||
""")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.last_cursor_position = (end_x, end_y)
|
||||
|
||||
async def add_cursor_box(self, page: Page, identifier: str) -> None:
|
||||
"""
|
||||
Add a red cursor box around the element with the given identifier.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
"""
|
||||
# animation helper
|
||||
await page.evaluate(f"""
|
||||
(function() {{
|
||||
let elm = document.querySelector("[__elementId='{identifier}']");
|
||||
if (elm) {{
|
||||
elm.style.transition = 'border 0.3s ease-in-out';
|
||||
elm.style.border = '2px solid red';
|
||||
}}
|
||||
}})();
|
||||
""")
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Create a red cursor
|
||||
await page.evaluate("""
|
||||
(function() {
|
||||
let cursor = document.createElement('div');
|
||||
cursor.id = 'red-cursor';
|
||||
cursor.style.width = '10px';
|
||||
cursor.style.height = '10px';
|
||||
cursor.style.backgroundColor = 'red';
|
||||
cursor.style.position = 'absolute';
|
||||
cursor.style.borderRadius = '50%';
|
||||
cursor.style.zIndex = '10000';
|
||||
document.body.appendChild(cursor);
|
||||
})();
|
||||
""")
|
||||
|
||||
async def remove_cursor_box(self, page: Page, identifier: str) -> None:
|
||||
"""
|
||||
Remove the red cursor box around the element with the given identifier.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
"""
|
||||
# Remove the highlight and cursor
|
||||
await page.evaluate(f"""
|
||||
(function() {{
|
||||
let elm = document.querySelector("[__elementId='{identifier}']");
|
||||
if (elm) {{
|
||||
elm.style.border = '';
|
||||
}}
|
||||
let cursor = document.getElementById('red-cursor');
|
||||
if (cursor) {{
|
||||
cursor.remove();
|
||||
}}
|
||||
}})();
|
||||
""")
|
||||
|
||||
async def click_id(self, page: Page, identifier: str) -> Page | None:
|
||||
"""
|
||||
Click the element with the given identifier.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
|
||||
Returns:
|
||||
Page | None: The new page if a new page is opened, otherwise None.
|
||||
"""
|
||||
new_page: Page | None = None
|
||||
assert page is not None
|
||||
target = page.locator(f"[__elementId='{identifier}']")
|
||||
|
||||
# See if it exists
|
||||
try:
|
||||
await target.wait_for(timeout=5000)
|
||||
except TimeoutError:
|
||||
raise ValueError("No such element.") from None
|
||||
|
||||
# Click it
|
||||
await target.scroll_into_view_if_needed()
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
|
||||
|
||||
if self.animate_actions:
|
||||
await self.add_cursor_box(page, identifier)
|
||||
# Move cursor to the box slowly
|
||||
start_x, start_y = self.last_cursor_position
|
||||
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
|
||||
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
# Give it a chance to open a new page
|
||||
async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore
|
||||
await page.mouse.click(end_x, end_y, delay=10)
|
||||
new_page = await page_info.value # type: ignore
|
||||
assert isinstance(new_page, Page)
|
||||
await self.on_new_page(new_page)
|
||||
except TimeoutError:
|
||||
pass
|
||||
await self.remove_cursor_box(page, identifier)
|
||||
|
||||
else:
|
||||
try:
|
||||
# Give it a chance to open a new page
|
||||
async with page.expect_event("popup", timeout=1000) as page_info: # type: ignore
|
||||
await page.mouse.click(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2, delay=10)
|
||||
new_page = await page_info.value # type: ignore
|
||||
assert isinstance(new_page, Page)
|
||||
await self.on_new_page(new_page)
|
||||
except TimeoutError:
|
||||
pass
|
||||
return new_page # type: ignore
|
||||
|
||||
async def hover_id(self, page: Page, identifier: str) -> None:
|
||||
"""
|
||||
Hover the mouse over the element with the given identifier.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
"""
|
||||
assert page is not None
|
||||
target = page.locator(f"[__elementId='{identifier}']")
|
||||
|
||||
# See if it exists
|
||||
try:
|
||||
await target.wait_for(timeout=5000)
|
||||
except TimeoutError:
|
||||
raise ValueError("No such element.") from None
|
||||
|
||||
# Hover over it
|
||||
await target.scroll_into_view_if_needed()
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
|
||||
|
||||
if self.animate_actions:
|
||||
await self.add_cursor_box(page, identifier)
|
||||
# Move cursor to the box slowly
|
||||
start_x, start_y = self.last_cursor_position
|
||||
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
|
||||
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
|
||||
await asyncio.sleep(0.1)
|
||||
await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2)
|
||||
|
||||
await self.remove_cursor_box(page, identifier)
|
||||
else:
|
||||
await page.mouse.move(box["x"] + box["width"] / 2, box["y"] + box["height"] / 2)
|
||||
|
||||
async def fill_id(self, page: Page, identifier: str, value: str, press_enter: bool = True) -> None:
|
||||
"""
|
||||
Fill the element with the given identifier with the specified value.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
value (str): The value to fill.
|
||||
"""
|
||||
assert page is not None
|
||||
target = page.locator(f"[__elementId='{identifier}']")
|
||||
|
||||
# See if it exists
|
||||
try:
|
||||
await target.wait_for(timeout=5000)
|
||||
except TimeoutError:
|
||||
raise ValueError("No such element.") from None
|
||||
|
||||
# Fill it
|
||||
await target.scroll_into_view_if_needed()
|
||||
box = cast(Dict[str, Union[int, float]], await target.bounding_box())
|
||||
|
||||
if self.animate_actions:
|
||||
await self.add_cursor_box(page, identifier)
|
||||
# Move cursor to the box slowly
|
||||
start_x, start_y = self.last_cursor_position
|
||||
end_x, end_y = box["x"] + box["width"] / 2, box["y"] + box["height"] / 2
|
||||
await self.gradual_cursor_animation(page, start_x, start_y, end_x, end_y)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Focus on the element
|
||||
await target.focus()
|
||||
if self.animate_actions:
|
||||
# fill char by char to mimic human speed for short text and type fast for long text
|
||||
if len(value) < 100:
|
||||
delay_typing_speed = 50 + 100 * random.random()
|
||||
else:
|
||||
delay_typing_speed = 10
|
||||
await target.press_sequentially(value, delay=delay_typing_speed)
|
||||
else:
|
||||
try:
|
||||
await target.fill(value)
|
||||
except PlaywrightError:
|
||||
await target.press_sequentially(value)
|
||||
if press_enter:
|
||||
await target.press("Enter")
|
||||
|
||||
if self.animate_actions:
|
||||
await self.remove_cursor_box(page, identifier)
|
||||
|
||||
async def scroll_id(self, page: Page, identifier: str, direction: str) -> None:
|
||||
"""
|
||||
Scroll the element with the given identifier in the specified direction.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
identifier (str): The element identifier.
|
||||
direction (str): The direction to scroll ("up" or "down").
|
||||
"""
|
||||
assert page is not None
|
||||
await page.evaluate(
|
||||
f"""
|
||||
(function() {{
|
||||
let elm = document.querySelector("[__elementId='{identifier}']");
|
||||
if (elm) {{
|
||||
if ("{direction}" == "up") {{
|
||||
elm.scrollTop = Math.max(0, elm.scrollTop - elm.clientHeight);
|
||||
}}
|
||||
else {{
|
||||
elm.scrollTop = Math.min(elm.scrollHeight - elm.clientHeight, elm.scrollTop + elm.clientHeight);
|
||||
}}
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
)
|
||||
|
||||
async def get_webpage_text(self, page: Page, n_lines: int = 50) -> str:
|
||||
"""
|
||||
Retrieve the text content of the web page.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
n_lines (int): The number of lines to return from the page inner text.
|
||||
|
||||
Returns:
|
||||
str: The text content of the page.
|
||||
"""
|
||||
assert page is not None
|
||||
try:
|
||||
text_in_viewport = await page.evaluate("""() => {
|
||||
return document.body.innerText;
|
||||
}""")
|
||||
text_in_viewport = "\n".join(text_in_viewport.split("\n")[:n_lines])
|
||||
# remove empty lines
|
||||
text_in_viewport = "\n".join([line for line in text_in_viewport.split("\n") if line.strip()])
|
||||
assert isinstance(text_in_viewport, str)
|
||||
return text_in_viewport
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
async def get_visible_text(self, page: Page) -> str:
|
||||
"""
|
||||
Retrieve the text content of the browser viewport (approximately).
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
str: The text content of the page.
|
||||
"""
|
||||
assert page is not None
|
||||
try:
|
||||
await page.evaluate(self._page_script)
|
||||
except Exception:
|
||||
pass
|
||||
result = await page.evaluate("MultimodalWebSurfer.getVisibleText();")
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
|
||||
async def get_page_markdown(self, page: Page) -> str:
|
||||
"""
|
||||
Retrieve the markdown content of the web page.
|
||||
Currently not implemented.
|
||||
|
||||
Args:
|
||||
page (Page): The Playwright page object.
|
||||
|
||||
Returns:
|
||||
str: The markdown content of the page.
|
||||
"""
|
||||
assert page is not None
|
||||
if self._markdown_converter is None and markitdown is not None:
|
||||
self._markdown_converter = markitdown.MarkItDown()
|
||||
assert self._markdown_converter is not None
|
||||
html = await page.evaluate("document.documentElement.outerHTML;")
|
||||
res = self._markdown_converter.convert_stream(
|
||||
io.BytesIO(html.encode("utf-8")), file_extension=".html", url=page.url
|
||||
)
|
||||
assert hasattr(res, "text_content") and isinstance(res.text_content, str)
|
||||
return res.text_content
|
||||
else:
|
||||
return await self.get_webpage_text(page, n_lines=200)
|
||||
Reference in New Issue
Block a user