first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
try:
    from ._azure_ai_agent import AzureAIAgent
except ImportError as e:
    # Re-raise with installation guidance. The package named in the sentence
    # must match the package named in the pip command that follows it.
    raise ImportError(
        "Dependencies for AzureAIAgent not found. "
        'Please install agentdhal-ext with the "azure" extra: '
        'pip install "agentdhal-ext[azure]"'
    ) from e
__all__ = ["AzureAIAgent"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
from typing import Any, Awaitable, Callable, Iterable, List, Literal, Optional, TypeGuard, Union
from agentdhal_core.tools import Tool
from pydantic import BaseModel, Field
from azure.ai.agents.models import (
AzureAISearchToolDefinition,
AzureFunctionToolDefinition,
BingGroundingToolDefinition,
CodeInterpreterToolDefinition,
FileSearchToolDefinition,
MessageTextUrlCitationAnnotation,
)
# Type alias for an iterable of tool specifications. Each item may be one of
# the literal tool-name strings below, an Azure tool definition object, a
# ``Tool`` instance, or a plain sync/async callable.
ListToolType = Iterable[
    Union[
        Literal[
            "file_search",
            "code_interpreter",
            "bing_grounding",
            "azure_ai_search",
            "azure_function",
        ],
        BingGroundingToolDefinition,
        CodeInterpreterToolDefinition,
        AzureAISearchToolDefinition,
        FileSearchToolDefinition,
        AzureFunctionToolDefinition,
        Tool,
        Callable[..., Any],
        Callable[..., Awaitable[Any]],
    ]
]
class AzureAIAgentState(BaseModel):
    """
    Represents the state of an AzureAIAgent that can be saved and loaded.

    This state model keeps track of persistent information about an agent session
    including agent and thread identifiers, message history, and associated resources.

    Attributes:
        type (str): The type identifier for the state object; defaults to "AzureAIAgentState"
        agent_id (Optional[str]): The ID of the Azure AI agent
        thread_id (Optional[str]): The ID of the conversation thread
        initial_message_ids (List[str]): List of message IDs from the initial state
        vector_store_id (Optional[str]): The ID of the associated vector store for file search
        uploaded_file_ids (List[str]): List of IDs for files uploaded to the agent
    """

    type: str = Field(default="AzureAIAgentState")
    agent_id: Optional[str] = None
    thread_id: Optional[str] = None
    # default_factory gives every instance its own list instead of a shared one.
    initial_message_ids: List[str] = Field(default_factory=list)
    vector_store_id: Optional[str] = None
    uploaded_file_ids: List[str] = Field(default_factory=list)
def has_annotations(obj: Any) -> TypeGuard[list[MessageTextUrlCitationAnnotation]]:
return obj is not None and isinstance(obj, list)

View File

@@ -0,0 +1,3 @@
from ._file_surfer import FileSurfer
__all__ = ["FileSurfer"]

View File

@@ -0,0 +1,208 @@
import json
import os
import traceback
from typing import List, Sequence, Tuple
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
BaseChatMessage,
TextMessage,
)
from agentdhal_agentchat.utils import remove_images
from agentdhal_core import CancellationToken, Component, ComponentModel, FunctionCall
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self
from ._markdown_file_browser import MarkdownFileBrowser
# from typing_extensions import Annotated
from ._tool_definitions import (
TOOL_FIND_NEXT,
TOOL_FIND_ON_PAGE_CTRL_F,
TOOL_OPEN_PATH,
TOOL_PAGE_DOWN,
TOOL_PAGE_UP,
)
class FileSurferConfig(BaseModel):
    """Configuration for FileSurfer agent"""

    # The agent's name.
    name: str
    # Serialized component description of the model client.
    model_client: ComponentModel
    # Optional agent description; None means no explicit description was stored.
    description: str | None = None
class FileSurfer(BaseChatAgent, Component[FileSurferConfig]):
    """An agent, used by MagenticOne, that acts as a local file previewer. FileSurfer can open and read a variety of common file types, and can navigate the local file hierarchy.

    Installation:

    .. code-block:: bash

        pip install "agentdhal-ext[file-surfer]"

    Args:
        name (str): The agent's name
        model_client (ChatCompletionClient): The model to use (must be tool-use enabled)
        description (str): The agent's description used by the team. Defaults to DEFAULT_DESCRIPTION
        base_path (str | None): The base path to use for the file browser. ``None`` (the default)
            resolves to the current working directory at construction time.
    """

    component_config_schema = FileSurferConfig
    component_provider_override = "agentdhal_extensions.agents.file_surfer.FileSurfer"

    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage(
            content="""
You are a helpful AI Assistant.
When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        description: str = DEFAULT_DESCRIPTION,
        base_path: str | None = None,
    ) -> None:
        super().__init__(name, description)
        self._model_client = model_client
        self._chat_history: List[LLMMessage] = []
        # Resolve the working directory when the agent is built. A default of
        # ``os.getcwd()`` in the signature is evaluated once at import time and
        # would go stale if the process changes directory afterwards.
        if base_path is None:
            base_path = os.getcwd()
        self._browser = MarkdownFileBrowser(viewport_size=1024 * 5, base_path=base_path)

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """The message types this agent can emit (text only)."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Record the incoming messages, generate one reply and return it.

        Failures while generating the reply are reported back as the text of the
        reply (including the traceback) rather than raised.
        """
        for chat_message in messages:
            self._chat_history.append(chat_message.to_model_message())

        try:
            _, content = await self._generate_reply(cancellation_token=cancellation_token)
        except Exception:
            # Only ordinary errors are turned into a reply. Task cancellation
            # (asyncio.CancelledError), KeyboardInterrupt and SystemExit derive
            # from BaseException and must propagate to the caller.
            content = f"File surfing error:\n\n{traceback.format_exc()}"

        self._chat_history.append(AssistantMessage(content=content, source=self.name))
        return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Forget the accumulated chat history."""
        self._chat_history.clear()

    def _get_browser_state(self) -> Tuple[str, str]:
        """
        Get the current state of the browser, including the header and content.
        """
        header = f"Path: {self._browser.path}\n"

        if self._browser.page_title is not None:
            header += f"Title: {self._browser.page_title}\n"

        current_page = self._browser.viewport_current_page
        total_pages = len(self._browser.viewport_pages)
        header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"

        return (header, self._browser.viewport)

    async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:
        """Ask the model what to do with the latest user message and act on it.

        Returns:
            A ``(flag, text)`` tuple. The flag is always ``False`` in this
            implementation; ``text`` is the model's direct answer, the browser
            state after a tool call, an error description, or ``"TERMINATE"``.

        Raises:
            ValueError: If the chat history is empty or its last entry is not a
                ``UserMessage``.
        """
        if not self._chat_history:
            raise ValueError("FileSurfer has no message to respond to.")

        history = self._chat_history[0:-1]
        last_message = self._chat_history[-1]
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not isinstance(last_message, UserMessage):
            raise ValueError("FileSurfer expects the most recent message to be a UserMessage.")

        task_content = last_message.content  # the last message from the sender is the task

        context_message = UserMessage(
            source="user",
            content=f"Your file viewer is currently open to the file or directory '{self._browser.page_title}' with path '{self._browser.path}'.",
        )

        task_message = UserMessage(
            source="user",
            content=task_content,
        )

        create_result = await self._model_client.create(
            messages=self._get_compatible_context(history + [context_message, task_message]),
            tools=[
                TOOL_OPEN_PATH,
                TOOL_PAGE_DOWN,
                TOOL_PAGE_UP,
                TOOL_FIND_NEXT,
                TOOL_FIND_ON_PAGE_CTRL_F,
            ],
            cancellation_token=cancellation_token,
        )

        response = create_result.content

        if isinstance(response, str):
            # Answer directly.
            return False, response

        elif isinstance(response, list) and all(isinstance(item, FunctionCall) for item in response):
            function_calls = response
            # The loop body always returns, so only the first requested call is
            # executed. An empty list falls through to "TERMINATE" below.
            for function_call in function_calls:
                tool_name = function_call.name

                try:
                    arguments = json.loads(function_call.arguments)
                except json.JSONDecodeError as e:
                    error_str = f"File surfer encountered an error decoding JSON arguments: {e}"
                    return False, error_str

                if tool_name == "open_path":
                    path = arguments["path"]
                    self._browser.open_path(path)
                elif tool_name == "page_up":
                    self._browser.page_up()
                elif tool_name == "page_down":
                    self._browser.page_down()
                elif tool_name == "find_on_page_ctrl_f":
                    search_string = arguments["search_string"]
                    self._browser.find_on_page(search_string)
                elif tool_name == "find_next":
                    self._browser.find_next()

                header, content = self._get_browser_state()
                final_response = header.strip() + "\n=======================\n" + content
                return False, final_response

        final_response = "TERMINATE"
        return False, final_response

    def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Ensure that the messages are compatible with the underlying client, by removing images if needed."""
        if self._model_client.model_info["vision"]:
            return messages
        else:
            return remove_images(messages)

    def _to_config(self) -> FileSurferConfig:
        """Serialize the agent. The browser's base path is not part of the config."""
        return FileSurferConfig(
            name=self.name,
            model_client=self._model_client.dump_component(),
            description=self.description,
        )

    @classmethod
    def _from_config(cls, config: FileSurferConfig) -> Self:
        """Rebuild an agent from its config; the base path falls back to the default."""
        return cls(
            name=config.name,
            model_client=ChatCompletionClient.load_component(config.model_client),
            description=config.description or cls.DEFAULT_DESCRIPTION,
        )

View File

@@ -0,0 +1,317 @@
# ruff: noqa: E722
import datetime
import io
import os
import re
import time
from typing import List, Optional, Tuple, Union
# TODO: Fix unfollowed import
from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException # type: ignore
class MarkdownFileBrowser:
"""
(In preview) An extremely simple Markdown-powered file browser.
"""
# TODO: Fix unfollowed import
def __init__( # type: ignore
self,
viewport_size: Union[int, None] = 1024 * 8,
base_path: str | None = os.getcwd(),
cwd: str | None = None,
):
"""
Instantiate a new MarkdownFileBrowser.
Arguments:
viewport_size: Approximately how many *characters* fit in the viewport. Viewport dimensions are adjusted dynamically to avoid cutting off words (default: 8192).
base_path: The base path to use for the file browser. Files outside this path cannot be accessed. Defaults to the current working directory.
cwd: The browser's current working directory. Defaults to the system's current working directory.
"""
self.viewport_size = viewport_size # Applies only to the standard uri types
self.history: List[Tuple[str, float]] = list()
self.page_title: Optional[str] = None
self.viewport_current_page = 0
self.viewport_pages: List[Tuple[int, int]] = list()
self._markdown_converter = MarkItDown()
self._base_path = None if base_path is None else os.path.realpath(base_path)
self._page_content: str = ""
self._find_on_page_query: Union[str, None] = None
self._find_on_page_last_result: Union[int, None] = None # Location of the last result
# Set the working directory
if cwd is None:
if self._validate_path(os.getcwd()):
# Use the current working directory if it's in the base path
cwd = os.path.realpath(os.getcwd())
elif self._base_path is not None:
# Otherwise, use the base path
cwd = os.path.realpath(self._base_path)
else:
raise ValueError("No valid working directory (cwd) provided.")
elif not self._validate_path(cwd):
# A cwd was provided, but it is not valid
raise ValueError(f"Working directory (cwd) '{cwd}' is not valid. It must be within the base path.")
# Populate the history with the current working directory
self.set_path(os.path.realpath(cwd))
@property
def path(self) -> str:
"""Return the path of the current page."""
assert len(self.history) > 0
return self.history[-1][0]
def _validate_path(self, path: str) -> bool:
"""Validates the path to ensure it is within the base path.
Arguments:
path: The path to validate.
Returns:
True if the path is valid, False otherwise.
"""
if self._base_path is None:
return True
# Normalize the paths
path = os.path.realpath(path)
base = os.path.realpath(self._base_path)
# Check if the path is within the base path
if os.path.commonpath([path, base]) != base:
return False
return True
def set_path(self, path: str) -> None:
"""Sets the path of the current page.
This will result in the file being opened for reading.
Arguments:
path: An absolute or relative path of the file or directory to open."
"""
# Handle relative paths
path = os.path.expanduser(path)
if not os.path.isabs(path):
if os.path.isfile(self.path):
path = os.path.abspath(os.path.join(os.path.dirname(self.path), path))
elif os.path.isdir(self.path):
path = os.path.abspath(os.path.join(self.path, path))
# If neither a file or a directory, take it verbatim
# Validating the path wrt. the base path is done in _open_path
path = os.path.realpath(path)
self.history.append((path, time.time()))
self._open_path(path)
self.viewport_current_page = 0
self.find_on_page_query = None
self.find_on_page_viewport = None
@property
def viewport(self) -> str:
"""Return the content of the current viewport."""
bounds = self.viewport_pages[self.viewport_current_page]
return self.page_content[bounds[0] : bounds[1]]
@property
def page_content(self) -> str:
"""Return the full contents of the current page."""
return self._page_content
def _set_page_content(self, content: str, split_pages: bool = True) -> None:
"""Sets the text content of the current page."""
self._page_content = content
if split_pages:
self._split_pages()
else:
self.viewport_pages = [(0, len(self._page_content))]
if self.viewport_current_page >= len(self.viewport_pages):
self.viewport_current_page = len(self.viewport_pages) - 1
def page_down(self) -> None:
"""Move the viewport down one page, if possible."""
self.viewport_current_page = min(self.viewport_current_page + 1, len(self.viewport_pages) - 1)
def page_up(self) -> None:
"""Move the viewport up one page, if possible."""
self.viewport_current_page = max(self.viewport_current_page - 1, 0)
def find_on_page(self, query: str) -> Union[str, None]:
"""Searches for the query from the current viewport forward, looping back to the start if necessary."""
# Did we get here via a previous find_on_page search with the same query?
# If so, map to find_next
if query == self._find_on_page_query and self.viewport_current_page == self._find_on_page_last_result:
return self.find_next()
# Ok it's a new search start from the current viewport
self._find_on_page_query = query
viewport_match = self._find_next_viewport(query, self.viewport_current_page)
if viewport_match is None:
self._find_on_page_last_result = None
return None
else:
self.viewport_current_page = viewport_match
self._find_on_page_last_result = viewport_match
return self.viewport
def find_next(self) -> Union[str, None]:
"""Scroll to the next viewport that matches the query"""
if self._find_on_page_query is None:
return None
starting_viewport = self._find_on_page_last_result
if starting_viewport is None:
starting_viewport = 0
else:
starting_viewport += 1
if starting_viewport >= len(self.viewport_pages):
starting_viewport = 0
viewport_match = self._find_next_viewport(self._find_on_page_query, starting_viewport)
if viewport_match is None:
self._find_on_page_last_result = None
return None
else:
self.viewport_current_page = viewport_match
self._find_on_page_last_result = viewport_match
return self.viewport
def _find_next_viewport(self, query: Optional[str], starting_viewport: int) -> Union[int, None]:
"""Search for matches between the starting viewport looping when reaching the end."""
if query is None:
return None
# Normalize the query, and convert to a regular expression
nquery = re.sub(r"\*", "__STAR__", query)
nquery = " " + (" ".join(re.split(r"\W+", nquery))).strip() + " "
nquery = nquery.replace(" __STAR__ ", "__STAR__ ") # Merge isolated stars with prior word
nquery = nquery.replace("__STAR__", ".*").lower()
if nquery.strip() == "":
return None
idxs: List[int] = list()
idxs.extend(range(starting_viewport, len(self.viewport_pages)))
idxs.extend(range(0, starting_viewport))
for i in idxs:
bounds = self.viewport_pages[i]
content = self.page_content[bounds[0] : bounds[1]]
# TODO: Remove markdown links and images
ncontent = " " + (" ".join(re.split(r"\W+", content))).strip().lower() + " "
if re.search(nquery, ncontent):
return i
return None
def open_path(self, path: str) -> str:
"""Open a file or directory in the file surfer."""
self.set_path(path)
return self.viewport
def _split_pages(self) -> None:
"""Split the page contents into pages that are approximately the viewport size. Small deviations are permitted to ensure words are not broken."""
# Handle empty pages
if len(self._page_content) == 0:
self.viewport_pages = [(0, 0)]
return
# Break the viewport into pages
self.viewport_pages = []
start_idx = 0
while start_idx < len(self._page_content):
end_idx = min(start_idx + self.viewport_size, len(self._page_content)) # type: ignore[operator]
# Adjust to end on a space
while end_idx < len(self._page_content) and self._page_content[end_idx - 1] not in [" ", "\t", "\r", "\n"]:
end_idx += 1
self.viewport_pages.append((start_idx, end_idx))
start_idx = end_idx
def _open_path(
self,
path: str,
) -> None:
"""Open a file for reading, converting it to Markdown in the process.
Arguments:
path: The path of the file or directory to open.
"""
if not self._validate_path(path):
# Not robust to TOCTOU issues.
# Mitigate by running with limited permissions, or use a sandbox.
self.page_title = "FileNotFoundError"
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
else:
try:
if os.path.isdir(path): # TODO: Fix markdown_converter types
res = self._markdown_converter.convert_stream( # type: ignore
io.BytesIO(self._fetch_local_dir(path).encode("utf-8")), file_extension=".txt"
)
assert self._validate_path(path)
self.page_title = res.title
self._set_page_content(res.text_content, split_pages=False)
else:
res = self._markdown_converter.convert_local(path)
assert self._validate_path(path)
self.page_title = res.title
self._set_page_content(res.text_content)
except UnsupportedFormatException:
self.page_title = "UnsupportedFormatException"
self._set_page_content(f"# UnsupportedFormatException\n\nCannot preview '{path}' as Markdown.")
except FileConversionException:
self.page_title = "FileConversionException."
self._set_page_content(f"# FileConversionException\n\nError converting '{path}' to Markdown.")
except FileNotFoundError:
self.page_title = "FileNotFoundError"
self._set_page_content(f"# FileNotFoundError\n\nFile not found: {path}")
def _fetch_local_dir(self, local_path: str) -> str:
"""Render a local directory listing in HTML to assist with local file browsing via the "file://" protocol.
Through rendered in HTML, later parts of the pipeline will convert the listing to Markdown.
Arguments:
local_path: A path to the local directory whose contents are to be listed.
Returns:
A directory listing, rendered in HTML.
"""
listing = f"""
# Index of {local_path}
| Name | Size | Date Modified |
| ---- | ---- | ------------- |
| .. (parent directory) | | |
"""
for entry in os.listdir(local_path):
size = ""
full_path = os.path.join(local_path, entry)
mtime = ""
try:
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(full_path)).strftime("%Y-%m-%d %H:%M")
except Exception as e:
# Handles PermissionError, etc.
mtime = f"N/A: {type(e).__name__}"
if os.path.isdir(full_path):
entry = entry + os.path.sep
else:
try:
size = str(os.path.getsize(full_path))
except Exception as e:
# Handles PermissionError, etc.
size = f"N/A: {type(e).__name__}"
listing += f"| {entry} | {size} | {mtime} |\n"
return listing

View File

@@ -0,0 +1,50 @@
from agentdhal_core.tools import ParametersSchema, ToolSchema
# Tool schemas describing the file-browser actions. The ``name`` values are
# the identifiers a caller has to match when handling a tool call.

# Open a file or directory; requires a single "path" string argument.
TOOL_OPEN_PATH = ToolSchema(
    name="open_path",
    description="Open a local file or directory at a path in the text-based file browser and return current viewport content.",
    parameters=ParametersSchema(
        type="object",
        properties={
            "path": {
                "type": "string",
                "description": "The relative or absolute path of a local file to visit.",
            },
        },
        required=["path"],
    ),
)

# Scroll up one page; declares no parameters.
TOOL_PAGE_UP = ToolSchema(
    name="page_up",
    description="Scroll the viewport UP one page-length in the current file and return the new viewport content.",
)

# Scroll down one page; declares no parameters.
TOOL_PAGE_DOWN = ToolSchema(
    name="page_down",
    description="Scroll the viewport DOWN one page-length in the current file and return the new viewport content.",
)

# Search the page; requires a single "search_string" argument.
TOOL_FIND_ON_PAGE_CTRL_F = ToolSchema(
    name="find_on_page_ctrl_f",
    description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F.",
    parameters=ParametersSchema(
        type="object",
        properties={
            "search_string": {
                "type": "string",
                "description": "The string to search for on the page. This search string supports wildcards like '*'",
            },
        },
        required=["search_string"],
    ),
)

# Jump to the next match of the previous search; declares no parameters.
TOOL_FIND_NEXT = ToolSchema(
    name="find_next",
    description="Scroll the viewport to next occurrence of the search string.",
)

View File

@@ -0,0 +1,10 @@
try:
    from ._magentic_one_coder_agent import MagenticOneCoderAgent
except ImportError as e:
    # Re-raise with installation guidance. The package named in the sentence
    # must match the package named in the pip command that follows it.
    raise ImportError(
        "Dependencies for MagenticOneCoderAgent not found. "
        'Please install agentdhal-ext with the "magentic-one" extra: '
        'pip install "agentdhal-ext[magentic-one]"'
    ) from e
__all__ = ["MagenticOneCoderAgent"]

View File

@@ -0,0 +1,41 @@
from typing import Any
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.models import (
ChatCompletionClient,
)
# Fixed description passed to the coder agent's base class.
MAGENTIC_ONE_CODER_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
# Fixed system prompt passed to the coder agent's base class. The text is
# runtime data: edit it only when a prompt change is intended.
MAGENTIC_ONE_CODER_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use the 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible."""
class MagenticOneCoderAgent(AssistantAgent):
    """An agent, used by MagenticOne that provides coding assistance using an LLM model client.

    The prompts and description are sealed, to replicate the original MagenticOne configuration. See AssistantAgent if you wish to modify these values.

    Args:
        name (str): The agent's name.
        model_client (ChatCompletionClient): The model client handed to the base class.
        **kwargs: Accepted but not used by this constructor.
    """

    component_provider_override = "agentdhal_extensions.agents.magentic_one.MagenticOneCoderAgent"

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        **kwargs: Any,
    ):
        # NOTE(review): ``kwargs`` is never forwarded to the base class, so any
        # extra keyword argument is silently ignored. Confirm this is intended
        # before relying on additional options here.
        super().__init__(
            name,
            model_client,
            description=MAGENTIC_ONE_CODER_DESCRIPTION,
            system_message=MAGENTIC_ONE_CODER_SYSTEM_MESSAGE,
        )

View File

@@ -0,0 +1,7 @@
from ._openai_agent import OpenAIAgent
from ._openai_assistant_agent import OpenAIAssistantAgent
__all__ = [
"OpenAIAgent",
"OpenAIAssistantAgent",
]

View File

@@ -0,0 +1,682 @@
import asyncio
import logging
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallSummaryMessage,
)
from agentdhal_core import CancellationToken, Component
from agentdhal_core.models import UserMessage
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore
# Number of characters to display when previewing image content in logs and UI
# Base64 encoded images can be very long, so we truncate for readability
IMAGE_CONTENT_PREVIEW_LENGTH = 50
# NOTE: We use the new Responses API, so ChatCompletion imports are not needed.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
# TypedDict classes for built-in tool configurations
class FileSearchToolConfig(TypedDict):
    """Configuration for file_search tool.

    ``type`` and ``vector_store_ids`` are required keys; the remaining keys
    may be omitted.
    """

    type: Literal["file_search"]
    vector_store_ids: List[str]  # required - The IDs of the vector stores to search
    max_num_results: NotRequired[int]  # optional
    ranking_options: NotRequired[Dict[str, Any]]  # optional
    filters: NotRequired[Dict[str, Any]]  # optional
class WebSearchToolConfig(TypedDict):
    """Configuration for web_search_preview tool.

    Only ``type`` is required.
    """

    type: Literal["web_search_preview"]
    search_context_size: NotRequired[str]  # optional - a string value
    user_location: NotRequired[Union[str, Dict[str, Any]]]  # optional - Can be string or structured location
class ComputerUseToolConfig(TypedDict):
    """Configuration for computer_use_preview tool.

    Every key is required.
    """

    type: Literal["computer_use_preview"]
    display_height: int  # required - Display height in pixels
    display_width: int  # required - Display width in pixels
    environment: str  # required - Environment type for computer use
class MCPToolConfig(TypedDict):
    """Configuration for mcp tool.

    ``type``, ``server_label`` and ``server_url`` are required keys; the
    remaining keys may be omitted.
    """

    type: Literal["mcp"]
    server_label: str  # required - Label for the MCP server
    server_url: str  # required - URL of the MCP server
    allowed_tools: NotRequired[List[str]]  # optional - List of allowed tools
    headers: NotRequired[Dict[str, str]]  # optional - HTTP headers for requests
    require_approval: NotRequired[bool]  # optional - Whether to require user approval
class CodeInterpreterToolConfig(TypedDict):
    """Configuration for code_interpreter tool.

    Both keys are required; ``container`` may be a string or a dict.
    """

    type: Literal["code_interpreter"]
    container: str | Dict[str, Any]  # required - Container configuration for code execution
class ImageGenerationToolConfig(TypedDict):
    """Configuration for image_generation tool.

    Only ``type`` is required.
    """

    type: Literal["image_generation"]
    background: NotRequired[str]  # optional - Background color or image
    input_image_mask: NotRequired[str]  # optional - Mask for input image editing
class LocalShellToolConfig(TypedDict):
    """Configuration for local_shell tool.

    WARNING: This tool is only supported with the 'codex-mini-latest' model
    and is available exclusively through the Responses API.

    ``type`` is the only key.
    """

    type: Literal["local_shell"]
    # Note: local_shell currently has no additional parameters in the API
# Union type for all built-in tool configurations
# A dict-style tool configuration is any one of the TypedDicts listed here;
# the ``type`` key of each identifies which built-in tool it configures.
BuiltinToolConfig = Union[
    FileSearchToolConfig,
    WebSearchToolConfig,
    ComputerUseToolConfig,
    MCPToolConfig,
    CodeInterpreterToolConfig,
    ImageGenerationToolConfig,
    LocalShellToolConfig,
]
# Define ImageMessage class early since it's used in _convert_message_to_openai_message
class ImageMessage(BaseChatMessage):
    """A chat message whose content is an image reference (URL or base64 string)."""

    content: str  # URL or base64 string

    def to_model_message(self) -> UserMessage:
        """Wrap the raw content in a ``UserMessage`` attributed to this message's source."""
        return UserMessage(source=self.source, content=self.content)

    def to_model_text(self) -> str:
        """Return the fixed placeholder used in place of the image."""
        return "[image]"

    def to_text(self) -> str:
        """Return a short, human-readable preview of the image content.

        Content longer than ``IMAGE_CONTENT_PREVIEW_LENGTH`` characters (base64
        payloads in particular) is cut to that length and followed by an
        ellipsis; shorter content is shown in full.
        """
        text = self.content
        if len(text) <= IMAGE_CONTENT_PREVIEW_LENGTH:
            return f"[Image: {text}]"
        return f"[Image: {text[:IMAGE_CONTENT_PREVIEW_LENGTH]}...]"
class OpenAIMessageContent(TypedDict):
    """A text content part of an OpenAI-format message."""

    type: str  # content-part discriminator; this module writes "text"
    text: str
class OpenAIImageUrlContent(TypedDict):
    """The ``image_url`` payload of an image content part."""

    url: str
class OpenAIImageContent(TypedDict):
    """An image content part of an OpenAI-format message."""

    type: str  # content-part discriminator; this module writes "image_url"
    image_url: OpenAIImageUrlContent
class OpenAIMessage(TypedDict):
    """An OpenAI-format message: a role plus plain text or a list of content parts."""

    role: str
    content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]]
def _convert_message_to_openai_message(
    message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage],
) -> OpenAIMessage:
    """Convert an AutoGen message to an OpenAI message format.

    * ``TextMessage``: the role is the message source when it is "user",
      "system" or "assistant"; any other source is sent as "user".
    * ``MultiModalMessage``: always role "user" with a list of content parts.
      Plain strings and ``TextMessage`` parts become text parts,
      ``ImageMessage`` parts become image_url parts.
    * Every other message type: role "user" with the stringified content.
    """
    if isinstance(message, TextMessage):
        role = message.source if message.source in ("user", "system", "assistant") else "user"
        return {"role": role, "content": str(message.content)}

    if isinstance(message, MultiModalMessage):
        content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = []
        for part in message.content:
            if isinstance(part, str):
                # Bare strings match neither message class below and used to be
                # dropped without notice; forward them as text.
                content_parts.append({"type": "text", "text": part})
            elif isinstance(part, TextMessage):
                content_parts.append({"type": "text", "text": str(part.content)})
            elif isinstance(part, ImageMessage):
                image_content = str(part.content)
                content_parts.append({"type": "image_url", "image_url": {"url": image_content}})
            # NOTE(review): parts of any other type are still skipped. Confirm
            # which part types MultiModalMessage.content can actually hold.
        return {"role": "user", "content": content_parts}

    return {"role": "user", "content": str(message.content)}
class OpenAIAgentState(BaseModel):
    """Serializable state of an OpenAI agent."""

    type: str = Field(default="OpenAIAgentState")
    # Optional response identifier; None when no identifier has been stored.
    response_id: Optional[str] = None
    # History entries as plain dicts; each instance gets its own list.
    history: List[Dict[str, Any]] = Field(default_factory=list)
class OpenAIAgentConfig(BaseModel):
    """
    Configuration model for OpenAI agent supporting OpenAI built-in tools only.

    .. versionchanged:: v0.7.0
        Added support for built-in tools in JSON configuration via _to_config and _from_config methods.
        The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format).
        Custom tools are not supported.
    """

    name: str
    description: str
    model: str
    instructions: str
    # Each entry is a built-in tool configuration dict or a built-in tool name.
    tools: List[Dict[str, Any] | str] | None = None
    temperature: Optional[float] = 1
    max_output_tokens: Optional[int] = None
    json_mode: bool = False
    store: bool = True
    truncation: str = "disabled"
class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]):
"""
An agent implementation that uses the OpenAI Responses API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]"
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Responses API to generate responses with capabilities like:
* Multi-turn conversations
* Built-in tool support (file_search, code_interpreter, web_search_preview, etc.)
Currently, custom tools are not supported.
.. versionchanged:: v0.7.0
Added support for built-in tool types like file_search, web_search_preview,
code_interpreter, computer_use_preview, image_generation, and mcp.
Added support for tool configurations with required and optional parameters.
Built-in tools are split into two categories:
**Tools that can use string format** (no required parameters):
- web_search_preview: Can be used as "web_search_preview" or with optional config
(user_location, search_context_size)
- image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask)
- local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model)
**Tools that REQUIRE dict configuration** (have required parameters):
- file_search: MUST use dict with vector_store_ids (List[str])
- computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str)
- code_interpreter: MUST use dict with container (str)
- mcp: MUST use dict with server_label (str), server_url (str)
Using required-parameter tools in string format will raise a ValueError with helpful error messages.
The tools parameter type annotation only accepts string values for tools that don't require parameters.
Note:
Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent.
Only OpenAI built-in tools provided via the Responses API are supported.
Args:
name (str): Name of the agent
description (str): Description of the agent's purpose
client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance
model (str): Model to use (e.g. "gpt-4.1")
instructions (str): System instructions for the agent
tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use.
Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell".
Dict values can provide configuration for built-in tools with parameters.
Required parameters for built-in tools:
- file_search: vector_store_ids (List[str])
- computer_use_preview: display_height (int), display_width (int), environment (str)
- code_interpreter: container (str)
- mcp: server_label (str), server_url (str)
Optional parameters for built-in tools:
- file_search: max_num_results (int), ranking_options (dict), filters (dict)
- web_search_preview: user_location (str or dict), search_context_size (int)
- image_generation: background (str), input_image_mask (str)
- mcp: allowed_tools (List[str]), headers (dict), require_approval (bool)
Special tools with model restrictions:
- local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support)
Custom tools are not supported.
temperature (Optional[float]): Temperature for response generation (default: 1)
max_output_tokens (Optional[int]): Maximum output tokens
json_mode (bool): Whether to use JSON mode (default: False)
store (bool): Whether to store conversations (default: True)
truncation (str): Truncation strategy (default: "disabled")
Example:
Basic usage with built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example():
client = AsyncOpenAI()
agent = OpenAIAgent(
name="SimpleAgent",
description="A simple OpenAI agent using the Responses API",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant.",
tools=["web_search_preview"], # Only tools without required params
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example())
Usage with configured built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example_with_configs():
client = AsyncOpenAI()
# Configure tools with required and optional parameters
tools = [
# {
# "type": "file_search",
# "vector_store_ids": ["vs_abc123"], # required
# "max_num_results": 10, # optional
# },
# {
# "type": "computer_use_preview",
# "display_height": 1024, # required
# "display_width": 1280, # required
# "environment": "linux", # required
# },
{
"type": "code_interpreter",
"container": {"type": "auto"}, # required
},
# {
# "type": "mcp",
# "server_label": "my-mcp-server", # required
# "server_url": "http://localhost:3000", # required
# },
{
"type": "web_search_preview",
"user_location": { # optional - structured location
"type": "approximate", # required: "approximate" or "exact"
"country": "US", # optional
"region": "CA", # optional
"city": "San Francisco", # optional
},
"search_context_size": "low", # optional
},
# "image_generation", # Simple tools can still use string format
]
agent = OpenAIAgent(
name="ConfiguredAgent",
description="An agent with configured tools",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant with specialized tools.",
tools=tools, # type: ignore
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example_with_configs())
Note:
Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API.
"""
# Declarative-config hooks for this component:
# - component_config_schema is the config model that `_to_config` returns and `_from_config` consumes.
# - component_provider_override is a dotted path string for this class
#   (presumably the public import path recorded in serialized configs -- TODO confirm).
component_config_schema = OpenAIAgentConfig
component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent"
def __init__(
    self: "OpenAIAgent",
    name: str,
    description: str,
    client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["web_search_preview", "image_generation", "local_shell"],
                BuiltinToolConfig,
            ]
        ]
    ] = None,
    temperature: Optional[float] = 1,
    max_output_tokens: Optional[int] = None,
    json_mode: bool = False,
    store: bool = True,
    truncation: str = "disabled",
) -> None:
    """Store the agent settings and register the requested built-in tools.

    String entries in ``tools`` are routed through ``_add_builtin_tool``;
    dict entries that carry a ``"type"`` key are stored as given.

    Raises:
        ValueError: If an entry is neither a string nor a dict with a
            ``"type"`` key, or if ``_add_builtin_tool`` rejects a string.
    """
    super().__init__(name, description)

    # Connection and generation settings.
    self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client
    self._model: str = model
    self._instructions: str = instructions
    self._temperature: Optional[float] = temperature
    self._max_output_tokens: Optional[int] = max_output_tokens
    self._json_mode: bool = json_mode
    self._store: bool = store
    self._truncation: str = truncation

    # Conversation state, reset by on_reset and persisted by save_state.
    self._last_response_id: Optional[str] = None
    self._message_history: List[Dict[str, Any]] = []

    # Tool definitions in the dict form sent to the API.
    self._tools: List[Dict[str, Any]] = []
    if tools is None:
        return

    for entry in tools:
        if isinstance(entry, str):
            # Built-in tool referenced by name only.
            self._add_builtin_tool(entry)
            continue
        if isinstance(entry, dict) and "type" in entry:
            # Pre-configured built-in tool; kept exactly as supplied.
            self._tools.append(cast(dict[str, Any], entry))
            continue
        raise ValueError(f"Unsupported tool type: {type(entry)}")
def _add_builtin_tool(self, tool_name: str) -> None:
"""Add a built-in tool by name."""
# Skip if an identical tool has already been registered (idempotent behaviour)
if any(td.get("type") == tool_name for td in self._tools):
return # Duplicate ignore rather than raise to stay backward-compatible
# Only allow string format for tools that don't require parameters
if tool_name == "web_search_preview":
self._tools.append({"type": "web_search_preview"})
elif tool_name == "image_generation":
self._tools.append({"type": "image_generation"})
elif tool_name == "local_shell":
# Special handling for local_shell - very limited model support
if self._model != "codex-mini-latest":
raise ValueError(
f"Tool 'local_shell' is only supported with model 'codex-mini-latest', "
f"but current model is '{self._model}'. "
f"This tool is available exclusively through the Responses API and has severe limitations. "
f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with "
f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead."
)
self._tools.append({"type": "local_shell"})
elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]:
# These tools require specific parameters and must use dict configuration
raise ValueError(
f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. "
f"Use dict configuration instead. Required parameters for {tool_name}: "
f"{self._get_required_params_help(tool_name)}"
)
else:
raise ValueError(f"Unsupported built-in tool type: {tool_name}")
def _get_required_params_help(self, tool_name: str) -> str:
"""Get help text for required parameters of a tool."""
help_text = {
"file_search": "vector_store_ids (List[str])",
"code_interpreter": "container (str | dict)",
"computer_use_preview": "display_height (int), display_width (int), environment (str)",
"mcp": "server_label (str), server_url (str)",
}
return help_text.get(tool_name, "unknown parameters")
def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]:
"""Convert an OpenAIMessage to a Dict[str, Any]."""
return dict(message)
@property
def produced_message_types(
    self: "OpenAIAgent",
) -> Sequence[
    Union[
        Type[TextMessage],
        Type[MultiModalMessage],
        Type[StopMessage],
        Type[ToolCallSummaryMessage],
        Type[HandoffMessage],
    ]
]:
    """List the message classes this agent declares as its possible outputs."""
    declared_types = [
        TextMessage,
        MultiModalMessage,
        StopMessage,
        ToolCallSummaryMessage,
        HandoffMessage,
    ]
    return declared_types
# Custom tool execution is not supported by this agent.
def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]:
has_system_message = any(msg.get("role") == "system" for msg in messages)
if self._instructions and not has_system_message:
messages = [{"role": "system", "content": self._instructions}] + messages
api_params: Dict[str, Any] = {
"model": self._model,
"input": messages, # Responses API expects 'input'
}
if self._temperature is not None:
api_params["temperature"] = self._temperature
if self._max_output_tokens is not None:
api_params["max_output_tokens"] = self._max_output_tokens
if self._tools:
api_params["tools"] = self._tools
if self._json_mode:
api_params["text"] = {"type": "json_object"}
api_params["store"] = self._store
api_params["truncation"] = self._truncation
if self._last_response_id:
api_params["previous_response_id"] = self._last_response_id
return api_params
async def on_messages(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> Response:
    """Drain ``on_messages_stream`` and return its final ``Response``.

    Every streamed item that is not a ``Response`` is collected and merged
    into the final response's ``inner_messages`` without duplicating entries
    that are already there.

    Raises:
        ValueError: If the stream finishes without yielding a ``Response``.
    """
    final_response = None
    streamed_items: List[
        Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
    ] = []

    async for item in self.on_messages_stream(messages, cancellation_token):
        if not isinstance(item, Response):
            streamed_items.append(item)
        else:
            # Later responses replace earlier ones; the last one wins.
            final_response = item

    if final_response is None:
        raise ValueError("No response was generated")

    if final_response.inner_messages is None:
        final_response.inner_messages = []

    for item in streamed_items:
        if item in final_response.inner_messages:
            continue
        final_response.inner_messages = [*final_response.inner_messages, item]

    return final_response
async def on_messages_stream(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[
    Union[
        AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response
    ],
    None,
]:
    """Send the accumulated history plus ``messages`` and yield one ``Response``.

    Incoming messages are converted to dicts and appended to the stored
    history before the request is made. On success the assistant text is
    appended to the history and the response id is remembered. If the request
    raises, the error is logged and a ``Response`` whose text describes the
    error is yielded instead of propagating the exception.
    """
    # Start from a copy of the stored history so the request list is independent of it.
    request_messages: List[Dict[str, Any]] = list(self._message_history) if self._message_history else []

    known_types = (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)
    for incoming in messages:
        if isinstance(incoming, known_types):
            as_dict = self._convert_message_to_dict(_convert_message_to_openai_message(incoming))
        else:
            # Fallback for other message kinds: stringify the content and send it as a user turn.
            if hasattr(incoming, "content"):
                fallback_text = str(cast(Any, incoming).content)
            else:
                fallback_text = str(incoming)
            as_dict = {"role": "user", "content": fallback_text}
        request_messages.append(as_dict)
        self._message_history.append(as_dict)

    inner_messages: List[AgentEvent | ChatMessage] = []
    request_kwargs = self._build_api_parameters(request_messages)

    try:
        untyped_client = cast(Any, self._client)
        pending = asyncio.ensure_future(untyped_client.responses.create(**request_kwargs))
        # Linked to the cancellation token so the in-flight request can be cancelled.
        api_response = await cancellation_token.link_future(pending)

        output_text = getattr(api_response, "output_text", None)
        self._last_response_id = getattr(api_response, "id", None)

        # Use a readable placeholder when the API returns no content to aid debugging
        if output_text is None:
            reply_text: str = "[no content returned]"
        else:
            reply_text = str(output_text)

        self._message_history.append({"role": "assistant", "content": reply_text})
        yield Response(
            chat_message=TextMessage(source=self.name, content=reply_text),
            inner_messages=inner_messages,
        )
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        event_logger.error(f"API error: {error_message}", exc_info=True)
        yield Response(
            chat_message=TextMessage(source=self.name, content=error_message),
            inner_messages=inner_messages,
        )
async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None:
    """Forget the last response id and start a fresh, empty message history."""
    # The history is rebound to a new list; the previous list object is left untouched.
    self._last_response_id, self._message_history = None, []
async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]:
    """Return the last response id and message history as a dumped ``OpenAIAgentState``."""
    snapshot = OpenAIAgentState(
        history=self._message_history,
        response_id=self._last_response_id,
    )
    return snapshot.model_dump()
async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None:
    """Validate ``state`` as an ``OpenAIAgentState`` and restore the agent from it."""
    restored = OpenAIAgentState.model_validate(state)
    self._last_response_id, self._message_history = restored.response_id, restored.history
def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig:
    """Build the declarative config describing this agent.

    The tool list is passed as a shallow copy. The client is not part of
    the config.

    Returns:
        OpenAIAgentConfig: The configuration that can recreate this agent.
    """
    tools_snapshot = list(self._tools)
    settings = {
        "name": self.name,
        "description": self.description,
        "model": self._model,
        "instructions": self._instructions,
        "tools": tools_snapshot,
        "temperature": self._temperature,
        "max_output_tokens": self._max_output_tokens,
        "json_mode": self._json_mode,
        "store": self._store,
        "truncation": self._truncation,
    }
    return OpenAIAgentConfig(**settings)
@classmethod
def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Recreate an agent from a declarative config.

    A new ``AsyncOpenAI`` client is constructed with no arguments, since the
    config does not carry a client.

    Args:
        config: The configuration to load the agent from.

    Returns:
        OpenAIAgent: The reconstructed agent.
    """
    from openai import AsyncOpenAI

    default_client = AsyncOpenAI()
    init_kwargs = {
        "name": config.name,
        "description": config.description,
        "client": default_client,
        "model": config.model,
        "instructions": config.instructions,
        "tools": config.tools,
        "temperature": config.temperature,
        "max_output_tokens": config.max_output_tokens,
        "json_mode": config.json_mode,
        "store": config.store,
        "truncation": config.truncation,
    }
    return cls(**init_kwargs)  # type: ignore
# Add public API wrappers for configuration and tools
def to_config(self) -> OpenAIAgentConfig:
    """Return this agent's declarative config by delegating to ``_to_config``."""
    config = self._to_config()
    return config
@classmethod
def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Build an agent from ``config`` by delegating to ``_from_config``."""
    agent = cls._from_config(config)
    return agent
@property
def tools(self) -> list[Any]:
    """The agent's registered tool definitions (the internal list itself, not a copy)."""
    registered = self._tools
    return registered
@property
def model(self) -> str:
    """The model name this agent was configured with."""
    configured_model = self._model
    return configured_model

View File

@@ -0,0 +1,715 @@
import asyncio
import json
import logging
import os
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Union,
cast,
)
import aiofiles
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from agentdhal_core import CancellationToken, FunctionCall, Image
from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult
from agentdhal_core.tools import FunctionTool, Tool
from pydantic import BaseModel, Field
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
from openai.pagination import AsyncCursorPage
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
from openai.types import FileObject
from openai.types.beta import thread_update_params
from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
from openai.types.beta.file_search_tool_param import FileSearchToolParam
from openai.types.beta.function_tool_param import FunctionToolParam
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
from openai.types.beta.threads import Message, MessageDeleted, Run
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
from openai.types.beta.threads.image_url_param import ImageURLParam
from openai.types.beta.threads.message_content_part_param import (
MessageContentPartParam,
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.vector_store import VectorStore
# Module-level logger obtained under the name agentchat exports as EVENT_LOGGER_NAME;
# used below to emit tool-call request/result events at debug level.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
    """Translate a ``Tool``'s schema into an Assistant API function tool parameter.

    Only the ``type``, ``properties`` and (when present) ``required`` entries
    of the schema's ``parameters`` section are carried over. A schema without
    a ``parameters`` section produces an empty parameter mapping.
    """
    tool_schema = tool.schema

    param_spec: Dict[str, object] = {}
    if "parameters" in tool_schema:
        declared = tool_schema["parameters"]
        param_spec = {"type": declared["type"], "properties": declared["properties"]}
        if "required" in declared:
            param_spec["required"] = declared["required"]

    definition = FunctionDefinition(
        name=tool_schema["name"],
        description=tool_schema.get("description", ""),
        parameters=param_spec,
    )
    return FunctionToolParam(type="function", function=definition)
class OpenAIAssistantAgentState(BaseModel):
    # Serializable snapshot of an OpenAIAssistantAgent.
    # NOTE(review): the save/load code that fills these fields is outside this view;
    # the per-field notes below are inferred from the agent attributes of the same name.

    # Discriminator string identifying this state model.
    type: str = Field(default="OpenAIAssistantAgentState")
    # ID of the assistant in use, if any.
    assistant_id: Optional[str] = Field(default=None)
    # ID of the conversation thread, if any.
    thread_id: Optional[str] = Field(default=None)
    # Message IDs that were already in the thread when the agent initialised.
    initial_message_ids: List[str] = Field(default_factory=list)
    # ID of the vector store created for file search, if any.
    vector_store_id: Optional[str] = Field(default=None)
    # IDs of files uploaded through the agent.
    uploaded_file_ids: List[str] = Field(default_factory=list)
class OpenAIAssistantAgent(BaseChatAgent):
"""An agent implementation that uses the Assistant API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]" # For OpenAI Assistant
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Assistant API to create AI assistants with capabilities like:
* Code interpretation and execution
* File handling and search
* Custom function calling
* Multi-turn conversations
The agent maintains a thread of conversation and can use various tools including
* Code interpreter: For executing code and working with files
* File search: For searching through uploaded documents
* Custom functions: For extending capabilities with user-defined tools
Key Features:
* Supports multiple file formats including code, documents, images
* Can handle up to 128 tools per assistant
* Maintains conversation context in threads
* Supports file uploads for code interpreter and search
* Vector store integration for efficient file search
* Automatic file parsing and embedding
You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters.
Examples:
Use the assistant to analyze data in a CSV file:
.. code-block:: python
from openai import AsyncOpenAI
from agentdhal_core import CancellationToken
import asyncio
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an OpenAI client
client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url")
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Upload files for the assistant to use
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
# Get response from the assistant
response = await assistant.on_messages(
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
)
print(response)
# Clean up resources
await assistant.delete_uploaded_files(cancellation_token)
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Use Azure OpenAI Assistant with AAD authentication:
.. code-block:: python
from openai import AsyncAzureOpenAI
import asyncio
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from agentdhal_core import CancellationToken
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an Azure OpenAI client
token_provider = get_bearer_token_provider(DefaultAzureCredential())
client = AsyncAzureOpenAI(
azure_deployment="YOUR_AZURE_DEPLOYMENT",
api_version="YOUR_API_VERSION",
azure_endpoint="YOUR_AZURE_ENDPOINT",
azure_ad_token_provider=token_provider,
)
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4o",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Get response from the assistant
response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token)
print(response)
# Clean up resources
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Args:
name (str): Name of the assistant
description (str): Description of the assistant's purpose
client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance
model (str): Model to use (e.g. "gpt-4")
instructions (str): System instructions for the assistant
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
assistant_id (Optional[str]): ID of existing assistant to use
thread_id (Optional[str]): ID of existing thread to use
metadata (Optional[Dict[str, str]]): Additional metadata for the assistant.
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
temperature (Optional[float]): Temperature for response generation
tool_resources (Optional[ToolResources]): Additional tool configuration
top_p (Optional[float]): Top p sampling parameter
"""
def __init__(
    self,
    name: str,
    description: str,
    client: AsyncOpenAI | AsyncAzureOpenAI,
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["code_interpreter", "file_search"],
                Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
            ]
        ]
    ] = None,
    assistant_id: Optional[str] = None,
    thread_id: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    response_format: Optional["AssistantResponseFormatOptionParam"] = None,
    temperature: Optional[float] = None,
    tool_resources: Optional["ToolResources"] = None,
    top_p: Optional[float] = None,
) -> None:
    """Validate the client, convert the tools and store the assistant settings.

    No API call is made here; the assistant and thread are created or
    retrieved later by ``_ensure_initialized``.

    Raises:
        ValueError: If ``client`` is a ``ChatCompletionClient``, if a string
            tool name is not ``"code_interpreter"`` or ``"file_search"``, or
            if a tool is neither a string, a ``Tool`` nor a callable.
    """
    if isinstance(client, ChatCompletionClient):
        raise ValueError(
            "Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
        )

    super().__init__(name, description)

    if tools is None:
        tools = []

    # Original Tool objects are kept for local execution; the converted
    # parameter dicts are what gets sent to the API.
    self._original_tools: List[Tool] = []
    converted_tools: List["AssistantToolParam"] = []
    for tool in tools:
        if isinstance(tool, str):
            if tool == "code_interpreter":
                converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
            elif tool == "file_search":
                converted_tools.append(FileSearchToolParam(type="file_search"))
            else:
                # Previously an unknown name was dropped silently, so a typo
                # produced an assistant without the intended tool.
                raise ValueError(
                    f"Unsupported built-in tool: {tool!r}. "
                    "Supported string values are 'code_interpreter' and 'file_search'."
                )
        elif isinstance(tool, Tool):
            self._original_tools.append(tool)
            converted_tools.append(_convert_tool_to_function_param(tool))
        elif callable(tool):
            # Use the callable's docstring as the tool description; a separate
            # name avoids shadowing the `description` constructor argument.
            tool_description = tool.__doc__ if tool.__doc__ is not None else ""
            function_tool = FunctionTool(tool, description=tool_description)
            self._original_tools.append(function_tool)
            converted_tools.append(_convert_tool_to_function_param(function_tool))
        else:
            raise ValueError(f"Unsupported tool type: {type(tool)}")

    self._client = client
    self._assistant: Optional["Assistant"] = None
    self._thread: Optional["Thread"] = None
    self._init_thread_id = thread_id
    self._model = model
    self._instructions = instructions
    self._api_tools = converted_tools
    self._assistant_id = assistant_id
    self._metadata = metadata
    self._response_format = response_format
    self._temperature = temperature
    self._tool_resources = tool_resources
    self._top_p = top_p
    self._vector_store_id: Optional[str] = None
    self._uploaded_file_ids: List[str] = []

    # Variables to track initial state
    self._initial_message_ids: Set[str] = set()
    self._initial_state_retrieved: bool = False
async def _ensure_initialized(self) -> None:
"""Ensure assistant and thread are created."""
if self._assistant is None:
if self._assistant_id:
self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id) # type: ignore[reportDeprecated]
else:
self._assistant = await self._client.beta.assistants.create( # type: ignore[reportDeprecated]
model=self._model,
description=self.description,
instructions=self._instructions,
tools=self._api_tools,
metadata=self._metadata,
response_format=self._response_format if self._response_format else NOT_GIVEN, # type: ignore
temperature=self._temperature,
tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN, # type: ignore
top_p=self._top_p,
)
if self._thread is None:
if self._init_thread_id:
self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id) # type: ignore[reportDeprecated]
else:
self._thread = await self._client.beta.threads.create() # type: ignore[reportDeprecated]
# Retrieve initial state only once
if not self._initial_state_retrieved:
await self._retrieve_initial_state()
self._initial_state_retrieved = True
async def _retrieve_initial_state(self) -> None:
    """Record the IDs of every message currently in the thread.

    The thread is paged through in ascending order, 100 messages at a time,
    and the collected IDs replace ``self._initial_message_ids``.
    """
    seen_ids: Set[str] = set()
    cursor: str | NotGiven = NOT_GIVEN
    more_pages = True
    while more_pages:
        page: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(  # type: ignore[reportDeprecated]
            self._thread_id, after=cursor, order="asc", limit=100
        )
        seen_ids.update(entry.id for entry in page.data)
        more_pages = page.has_next_page()
        if more_pages:
            # Continue after the last message of the page just read.
            cursor = page.data[-1].id
    self._initial_message_ids = seen_ids
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
    """Message classes this agent produces: only ``TextMessage``."""
    produced = (TextMessage,)
    return produced
@property
def threads(self) -> AsyncThreads:
    """The client's ``beta.threads`` resource."""
    beta_api = self._client.beta
    return beta_api.threads
@property
def runs(self) -> AsyncRuns:
    """The client's ``beta.threads.runs`` resource."""
    threads_api = self._client.beta.threads
    return threads_api.runs
@property
def messages(self) -> AsyncMessages:
    """The client's ``beta.threads.messages`` resource."""
    threads_api = self._client.beta.threads
    return threads_api.messages
@property
def _get_assistant_id(self) -> str:
if self._assistant is None:
raise ValueError("Assistant not initialized")
return self._assistant.id
@property
def _thread_id(self) -> str:
if self._thread is None:
raise ValueError("Thread not initialized")
return self._thread.id
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
"""Execute a tool call and return the result."""
if not self._original_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token, call_id=tool_call.id)
return tool.return_value_as_string(result)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
    """Consume ``on_messages_stream`` and return the first ``Response`` it yields.

    Raises:
        AssertionError: If the stream ends without yielding a ``Response``.
    """
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if not isinstance(item, Response):
            # Intermediate events are ignored by the non-streaming entry point.
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
    self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
    """Add ``messages`` to the thread, run the assistant and stream the outcome.

    Yields a ``ToolCallRequestEvent`` and a ``ToolCallExecutionEvent`` for
    each round of function calls the run requests, then a final ``Response``
    holding the first text part of the newest thread message.

    Raises:
        ValueError: If the run fails or ends in another terminal status
            without completing (``cancelled``, ``expired``, ``incomplete``),
            if the thread returns no messages, or if the newest message has
            no content or no text content.
    """
    await self._ensure_initialized()

    # Process all messages in sequence
    for message in messages:
        await self.handle_incoming_message(message, cancellation_token)

    # Inner messages for tool calls
    inner_messages: List[BaseAgentEvent | BaseChatMessage] = []

    # Create and start a run
    run: Run = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.runs.create(  # type: ignore[reportDeprecated]
                thread_id=self._thread_id,
                assistant_id=self._get_assistant_id,
            )
        )
    )

    # Wait for run completion by polling
    while True:
        run = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.runs.retrieve(  # type: ignore[reportDeprecated]
                    thread_id=self._thread_id,
                    run_id=run.id,
                )
            )
        )

        if run.status == "failed":
            raise ValueError(f"Run failed: {run.last_error}")

        # These statuses are terminal too; without this check the loop
        # would keep polling a run that can never complete.
        if run.status in ("cancelled", "expired", "incomplete"):
            raise ValueError(f"Run ended with status '{run.status}': {run.last_error}")

        # If the run requires action (function calls), execute tools and continue
        if run.status == "requires_action" and run.required_action is not None:
            tool_calls: List[FunctionCall] = [
                FunctionCall(
                    id=required_tool_call.id,
                    name=required_tool_call.function.name,
                    arguments=required_tool_call.function.arguments,
                )
                for required_tool_call in run.required_action.submit_tool_outputs.tool_calls
                if required_tool_call.type == "function"
            ]

            # Add tool call message to inner messages
            tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
            inner_messages.append(tool_call_msg)
            event_logger.debug(tool_call_msg)
            yield tool_call_msg

            # Execute tool calls and get results; a failing tool is reported
            # back to the run as an error string instead of aborting it.
            tool_outputs: List[FunctionExecutionResult] = []
            for tool_call in tool_calls:
                try:
                    result = await self._execute_tool_call(tool_call, cancellation_token)
                    is_error = False
                except Exception as e:
                    result = f"Error: {e}"
                    is_error = True
                tool_outputs.append(
                    FunctionExecutionResult(
                        content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
                    )
                )

            # Add tool result message to inner messages
            tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
            inner_messages.append(tool_result_msg)
            event_logger.debug(tool_result_msg)
            yield tool_result_msg

            # Submit tool outputs back to the run
            run = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.runs.submit_tool_outputs(  # type: ignore[reportDeprecated]
                        thread_id=self._thread_id,
                        run_id=run.id,
                        tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
                    )
                )
            )
            continue

        if run.status == "completed":
            break

        await asyncio.sleep(0.5)

    # Get messages after run completion (newest first, one message)
    assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)  # type: ignore[reportDeprecated]
        )
    )

    if not assistant_messages.data:
        raise ValueError("No messages received from assistant")

    # Get the last message's content
    last_message = assistant_messages.data[0]
    if not last_message.content:
        raise ValueError(f"No content in the last message: {last_message}")

    # Extract text content
    text_content = [content for content in last_message.content if content.type == "text"]
    if not text_content:
        raise ValueError(f"Expected text content in the last message: {last_message.content}")

    # Return the assistant's response as a Response with inner messages
    chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
    yield Response(chat_message=chat_message, inner_messages=inner_messages)
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
    """Append ``message`` to the thread as a ``user`` message.

    String content is sent as-is. List content is converted part by part:
    strings become text blocks and ``Image`` objects become image-URL blocks
    built from their data URI.

    Raises:
        ValueError: If a content part is neither a string nor an ``Image``.
    """

    def _to_part(piece: Any) -> MessageContentPartParam:
        # Convert one content element into the Assistant API block form.
        if isinstance(piece, str):
            return TextContentBlockParam(text=piece, type="text")
        if isinstance(piece, Image):
            return ImageURLContentBlockParam(image_url=ImageURLParam(url=piece.data_uri), type="image_url")
        raise ValueError(f"Unsupported content type: {type(piece)} in {message}")

    model_message = message.to_model_message()
    payload: str | List[MessageContentPartParam] | None = None
    if isinstance(model_message.content, str):
        payload = model_message.content
    else:
        payload = []
        for piece in model_message.content:
            payload.append(_to_part(piece))

    create_call = self._client.beta.threads.messages.create(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        content=payload,
        role="user",
    )
    await cancellation_token.link_future(asyncio.ensure_future(create_call))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """Delete every thread message that was not present at initialisation.

    The thread is paged through in ascending order, 100 messages at a time;
    IDs missing from ``self._initial_message_ids`` are collected first and
    then deleted one by one.
    """
    await self._ensure_initialized()

    # Pass 1: gather the IDs of messages added since initialisation.
    added_ids: List[str] = []
    cursor: str | NotGiven = NOT_GIVEN
    while True:
        page: AsyncCursorPage[Message] = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.list(self._thread_id, after=cursor, order="asc", limit=100)  # type: ignore[reportDeprecated]
            )
        )
        added_ids.extend(entry.id for entry in page.data if entry.id not in self._initial_message_ids)
        if not page.has_next_page():
            break
        cursor = page.data[-1].id

    # Pass 2: delete them, checking that each deletion is acknowledged.
    for added_id in added_ids:
        status: MessageDeleted = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.delete(message_id=added_id, thread_id=self._thread_id)  # type: ignore[reportDeprecated]
            )
        )
        assert status.deleted is True
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
    """Upload one or more local files with purpose ``"assistants"``.

    Args:
        file_paths: A single path or an iterable of paths.
        cancellation_token: Token linked to each read and upload.

    Returns:
        The IDs of the uploaded files, in upload order. Each ID is also
        appended to ``self._uploaded_file_ids``.
    """
    await self._ensure_initialized()

    # A bare string is treated as a single path, not as an iterable of characters.
    paths = [file_paths] if isinstance(file_paths, str) else file_paths

    uploaded_ids: List[str] = []
    for path in paths:
        async with aiofiles.open(path, mode="rb") as handle:
            payload = await cancellation_token.link_future(asyncio.ensure_future(handle.read()))
        base_name = os.path.basename(path)

        upload_call = self._client.files.create(file=(base_name, payload), purpose="assistants")
        created: FileObject = await cancellation_token.link_future(asyncio.ensure_future(upload_call))

        uploaded_ids.append(created.id)
        self._uploaded_file_ids.append(created.id)
    return uploaded_ids
async def on_upload_for_code_interpreter(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and attach them to the thread's code-interpreter resources.

    The new file IDs are appended to whatever code-interpreter file IDs the
    thread already has, then the thread is updated with the combined list.
    """
    await self._ensure_initialized()
    new_ids = await self._upload_files(file_paths, cancellation_token)

    # Fetch the current thread so the existing attachments are preserved.
    current_thread = await cancellation_token.link_future(
        asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))  # type: ignore[reportDeprecated]
    )
    resources: ToolResources = current_thread.tool_resources or ToolResources()
    interpreter: ToolResourcesCodeInterpreter = resources.code_interpreter or ToolResourcesCodeInterpreter()
    combined_ids: List[str] = interpreter.file_ids or []
    combined_ids += new_ids
    resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=combined_ids)

    update_call = self._client.beta.threads.update(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        tool_resources=cast(thread_update_params.ToolResources, resources.model_dump()),
    )
    await cancellation_token.link_future(asyncio.ensure_future(update_call))
async def on_upload_for_file_search(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and add them to the vector store used for file search.

    Raises:
        ValueError: If no tool of type ``file_search`` is present in ``self._api_tools``.
    """
    await self._ensure_initialized()

    # File search must have been enabled when the assistant was created.
    file_search_enabled = any(tool.get("type") == "file_search" for tool in self._api_tools)
    if not file_search_enabled:
        raise ValueError(
            "File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
        )

    if self._vector_store_id is None:
        # First upload: create the vector store and attach it to the assistant.
        created_store: VectorStore = await cancellation_token.link_future(
            asyncio.ensure_future(self._client.vector_stores.create())
        )
        self._vector_store_id = created_store.id
        attach_call = self._client.beta.assistants.update(
            assistant_id=self._get_assistant_id,
            tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
        )
        await cancellation_token.link_future(asyncio.ensure_future(attach_call))

    new_ids = await self._upload_files(file_paths, cancellation_token)

    # Register the uploaded files with the vector store and wait for the batch to finish.
    batch_call = self._client.vector_stores.file_batches.create_and_poll(
        vector_store_id=self._vector_store_id, file_ids=new_ids
    )
    await cancellation_token.link_future(asyncio.ensure_future(batch_call))
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
    """Delete every file recorded in ``self._uploaded_file_ids``.

    Deletion is best-effort: a failure for one file is logged and the loop
    continues. The recorded ID list is emptied afterwards in all cases.
    """
    await self._ensure_initialized()
    for uploaded_id in self._uploaded_file_ids:
        try:
            pending = asyncio.ensure_future(self._client.files.delete(file_id=uploaded_id))
            await cancellation_token.link_future(pending)
        except Exception as e:
            event_logger.error(f"Failed to delete file {uploaded_id}: {str(e)}")
    self._uploaded_file_ids = []
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
    """Delete the assistant, but only when no assistant ID is set on this instance.

    Nothing happens when there is no assistant object or when
    ``self._assistant_id`` is set. Failures are logged, not raised.
    """
    await self._ensure_initialized()
    if self._assistant is None or self._assistant_id:
        return
    try:
        pending = asyncio.ensure_future(
            self._client.beta.assistants.delete(assistant_id=self._get_assistant_id)  # type: ignore[reportDeprecated]
        )
        await cancellation_token.link_future(pending)
        self._assistant = None
    except Exception as e:
        event_logger.error(f"Failed to delete assistant: {str(e)}")
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
    """Delete the vector store referenced by ``self._vector_store_id``, if any.

    On success the stored ID is cleared. Failures are logged, not raised.
    """
    await self._ensure_initialized()
    if self._vector_store_id is None:
        return
    try:
        pending = asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=self._vector_store_id))
        await cancellation_token.link_future(pending)
        self._vector_store_id = None
    except Exception as e:
        event_logger.error(f"Failed to delete vector store: {str(e)}")
async def save_state(self) -> Mapping[str, Any]:
    """Return the agent's persistent identifiers as a plain mapping.

    The IDs of the live assistant and thread objects are used when those
    objects exist; otherwise the IDs this instance was configured with are used.
    """
    assistant_id = self._assistant.id if self._assistant else self._assistant_id
    thread_id = self._thread.id if self._thread else self._init_thread_id
    snapshot = OpenAIAssistantAgentState(
        assistant_id=assistant_id,
        thread_id=thread_id,
        initial_message_ids=list(self._initial_message_ids),
        vector_store_id=self._vector_store_id,
        uploaded_file_ids=self._uploaded_file_ids,
    )
    return snapshot.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Validate ``state`` and copy its fields onto this instance.

    The mapping is validated with ``OpenAIAssistantAgentState.model_validate``.
    """
    restored = OpenAIAssistantAgentState.model_validate(state)
    # Identifiers used to re-attach to the existing assistant and thread.
    self._assistant_id = restored.assistant_id
    self._init_thread_id = restored.thread_id
    # Stored as a set on the instance; the state model carries a list.
    self._initial_message_ids = set(restored.initial_message_ids)
    self._vector_store_id = restored.vector_store_id
    self._uploaded_file_ids = restored.uploaded_file_ids

View File

@@ -0,0 +1,3 @@
from ._video_surfer import VideoSurfer
__all__ = ["VideoSurfer"]

View File

@@ -0,0 +1,172 @@
from typing import Any, Awaitable, Callable, List, Optional
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.models import ChatCompletionClient
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel
from .tools import (
extract_audio,
get_screenshot_at,
get_video_length,
save_screenshot,
transcribe_audio_with_timestamps,
transcribe_video_screenshot,
)
class VideoSurfer(AssistantAgent):
    """
    VideoSurfer is a specialized agent designed to answer questions about a local video file.

    Installation:

    .. code-block:: bash

        pip install "agentdhal-ext[video-surfer]"

    This agent utilizes various tools to extract information from the video, such as its length,
    screenshots at specific timestamps, and audio transcriptions. It processes these elements to
    provide detailed answers to user queries.

    Available tools:

    - :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot`

    Args:
        name (str): The name of the agent.
        model_client (ChatCompletionClient): The model client used for generating responses.
        tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
            A list of tools or functions the agent can use. If not provided (or empty), defaults to all video tools from the action space.
        description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
        system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.

    Example usage:

    The following example demonstrates how to create a video surfing agent with
    a model client and generate a response to a simple query about a local video
    called video.mp4.

        .. code-block:: python

            import asyncio
            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"
                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
                    )

                # Define termination condition
                termination = TextMentionTermination("TERMINATE")

                # Define a team
                agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of complex does his use? Can you save this example to disk as well?")
                await Console(stream)

            asyncio.run(main())

    The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat.

        .. code-block:: python

            import asyncio
            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.teams import MagenticOneGroupChat
            from agentdhal_agentchat.agents import UserProxyAgent
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"
                model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")

                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=model_client
                    )

                web_surfer_agent = UserProxyAgent(
                    name="User"
                )

                # Define a team
                agent_team = MagenticOneGroupChat([web_surfer_agent, video_agent], model_client=model_client,)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.")
                await Console(stream)

            asyncio.run(main())
    """

    DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video."
    DEFAULT_SYSTEM_MESSAGE = """
You are a helpful agent that is an expert at answering questions from a video.
When asked to answer a question about a video, you should:
1. Check if that video is available locally.
2. Use the transcription to find which part of the video the question is referring to.
3. Optionally use screenshots from those timestamps
4. Provide a detailed answer to the question.
Reply with TERMINATE when the task has been completed.
"""

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        *,
        tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
        description: Optional[str] = None,
        system_message: Optional[str] = None,
    ):
        # A missing or empty tool list selects the built-in video tool set. The
        # screenshot transcriber is the bound wrapper below, which supplies the
        # agent's own model client to the underlying tool.
        if tools:
            selected_tools = tools
        else:
            selected_tools = [
                get_video_length,
                get_screenshot_at,
                save_screenshot,
                self.vs_transribe_video_screenshot,
                extract_audio,
                transcribe_audio_with_timestamps,
            ]
        effective_description = description if description else self.DEFAULT_DESCRIPTION
        effective_system_message = system_message if system_message else self.DEFAULT_SYSTEM_MESSAGE
        super().__init__(
            name=name,
            model_client=model_client,
            tools=selected_tools,
            description=effective_description,
            system_message=effective_system_message,
        )

    async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str:
        """
        Transcribes the video screenshot at a specific timestamp.

        Args:
            video_path (str): Path to the video file.
            timestamp (float): Timestamp to take the screenshot.

        Returns:
            str: Transcription of the video screenshot.
        """
        transcription = await transcribe_video_screenshot(video_path, timestamp, self._model_client)
        return transcription

View File

@@ -0,0 +1,156 @@
import base64
from typing import Any, Dict, List, Tuple
import cv2
import ffmpeg
import numpy as np
import whisper
from agentdhal_core import Image as AGImage
from agentdhal_core.models import (
ChatCompletionClient,
UserMessage,
)
def extract_audio(video_path: str, audio_output_path: str) -> str:
    """
    Pull the audio track out of a video and write it to disk as MP3.

    :param video_path: Path to the video file.
    :param audio_output_path: Path to save the extracted audio file (overwritten if it exists).
    :return: Confirmation message with the path to the saved audio file.
    """
    pipeline = ffmpeg.input(video_path).output(audio_output_path, format="mp3")  # type: ignore
    pipeline.run(quiet=True, overwrite_output=True)  # type: ignore
    return f"Audio extracted and saved to {audio_output_path}."
def transcribe_audio_with_timestamps(audio_path: str) -> str:
    """
    Transcribe an audio file with the Whisper "base" model, one line per segment.

    Each output line has the form ``[start - end] text`` with times printed to
    two decimals, and is terminated by a newline.

    :param audio_path: Path to the audio file.
    :return: Transcription with timestamps.
    """
    whisper_model = whisper.load_model("base")  # type: ignore
    output: Dict[str, Any] = whisper_model.transcribe(audio_path, task="transcribe", language="en", verbose=False)  # type: ignore
    pieces: List[Dict[str, Any]] = output["segments"]
    lines: List[str] = []
    for piece in pieces:
        begin: float = piece["start"]
        finish: float = piece["end"]
        spoken: str = piece["text"]
        lines.append(f"[{begin:.2f} - {finish:.2f}] {spoken}\n")
    return "".join(lines)
def get_video_length(video_path: str) -> str:
    """
    Returns the length of the video in seconds.

    The duration is computed as frame count divided by frames per second, both
    as reported by OpenCV for the file.

    :param video_path: Path to the video file.
    :return: Duration of the video in seconds.
    :raises IOError: If the file cannot be opened or reports a non-positive frame rate.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        # Always release the capture handle, including on the error paths.
        cap.release()
    if fps <= 0:
        # A zero frame rate would otherwise surface as a bare ZeroDivisionError.
        raise IOError(f"Cannot determine frame rate of video file {video_path}")
    duration = frame_count / fps
    return f"The video is {duration:.2f} seconds long."
def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None:
    """
    Captures a screenshot at the specified timestamp and saves it to the output path.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param output_path: Path to save the screenshot. The file format is determined by the extension in the path.
    :raises IOError: If the video cannot be opened, the frame cannot be read, or the image cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Convert the timestamp to a frame index and seek there.
        frame_number = int(timestamp * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
        if not ret:
            raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
        # imwrite signals failure through its return value; do not report success silently.
        if not cv2.imwrite(output_path, frame):
            raise IOError(f"Failed to write screenshot to {output_path}")
    finally:
        # Release the capture handle on success and on every error path.
        cap.release()
async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str:
    """
    Describes the content of a video screenshot captured at the specified timestamp using the given model client.

    The frame is JPEG-encoded, embedded as a base64 data URI and sent to the
    model together with a prompt that states the timestamp.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param model_client: ChatCompletionClient instance.
    :return: Description of the screenshot content, or a failure message if no frame could be captured or encoded.
    """
    screenshots = get_screenshot_at(video_path, [timestamp])
    if not screenshots:
        return "Failed to capture screenshot."
    _, frame = screenshots[0]

    # Convert the frame to JPEG bytes and then to a base64 data URI.
    encoded_ok, buffer = cv2.imencode(".jpg", frame)
    if not encoded_ok:
        return "Failed to encode screenshot."
    frame_base64 = base64.b64encode(buffer.tobytes()).decode("utf-8")
    screenshot_uri = f"data:image/jpeg;base64,{frame_base64}"

    messages = [
        UserMessage(
            content=[
                # The timestamp is interpolated here; the prompt previously contained a literal "{}".
                f"Following is a screenshot from the video at {timestamp} seconds. Describe what you see here.",
                AGImage.from_uri(screenshot_uri),
            ],
            source="tool",
        )
    ]
    result = await model_client.create(messages=messages)
    return str(result.content)
def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]:
    """
    Captures screenshots at the specified timestamps and returns them as Python objects.

    :param video_path: Path to the video file.
    :param timestamps: List of timestamps in seconds.
    :return: List of tuples containing timestamp and the corresponding frame (image).
             Each frame is a NumPy array (height x width x channels).
    :raises IOError: If the video cannot be opened, reports a non-positive frame rate, or a frame cannot be read.
    :raises ValueError: If a timestamp lies outside ``[0, duration]``.
    """
    screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # A zero frame rate would otherwise surface as a bare ZeroDivisionError below.
            raise IOError(f"Cannot determine frame rate of video file {video_path}")
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = total_frames / fps
        for timestamp in timestamps:
            if not 0 <= timestamp <= duration:
                raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]")
            # Seek to the frame that corresponds to this timestamp and read it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(timestamp * fps))
            ret, frame = cap.read()
            if not ret:
                raise IOError(f"Failed to capture frame at {timestamp:.2f}s")
            screenshots.append((timestamp, frame))
    finally:
        # Release the capture handle on success and on every error path.
        cap.release()
    return screenshots

View File

@@ -0,0 +1,4 @@
from ._multimodal_web_surfer import MultimodalWebSurfer
from .playwright_controller import PlaywrightController
__all__ = ["MultimodalWebSurfer", "PlaywrightController"]

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class WebSurferEvent:
    """Structured record describing a single web surfer event.

    The three leading fields are required; ``action`` and ``arguments``
    default to ``None``.
    """

    # Identifier of whatever emitted the event.
    source: str
    # Human-readable description of the event.
    message: str
    # URL the event relates to.
    url: str
    # Presumably the name of the action/tool performed, if any -- TODO confirm against callers.
    action: str | None = None
    # Presumably the arguments supplied to that action, if any -- TODO confirm against callers.
    arguments: Dict[str, Any] | None = None

View File

@@ -0,0 +1,988 @@
import asyncio
import base64
import hashlib
import io
import json
import logging
import os
import re
import sys
import time
import traceback
import warnings
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Optional,
Sequence,
)
from urllib.parse import quote_plus
import aiofiles
import PIL.Image
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, MultiModalMessage, TextMessage
from agentdhal_agentchat.utils import content_to_str, remove_images
from agentdhal_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from agentdhal_core import Image as AGImage
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
ModelFamily,
RequestUsage,
SystemMessage,
UserMessage,
)
from PIL import Image
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from pydantic import BaseModel
from typing_extensions import Self
from ._events import WebSurferEvent
from ._prompts import (
WEB_SURFER_QA_PROMPT,
WEB_SURFER_QA_SYSTEM_MESSAGE,
WEB_SURFER_TOOL_PROMPT_MM,
WEB_SURFER_TOOL_PROMPT_TEXT,
)
from ._set_of_mark import add_set_of_mark
from ._tool_definitions import (
TOOL_CLICK,
TOOL_HISTORY_BACK,
TOOL_HOVER,
TOOL_READ_PAGE_AND_ANSWER,
TOOL_SCROLL_DOWN,
TOOL_SCROLL_UP,
TOOL_SLEEP,
TOOL_SUMMARIZE_PAGE,
TOOL_TYPE,
TOOL_VISIT_URL,
TOOL_WEB_SEARCH,
)
from ._types import InteractiveRegion, UserContent
from .playwright_controller import PlaywrightController
# Default context size; presumably measured in tokens -- TODO confirm where it is consumed.
DEFAULT_CONTEXT_SIZE = 128000
class MultimodalWebSurferConfig(BaseModel):
    """Declarative configuration for a ``MultimodalWebSurfer`` component.

    Only ``name`` and ``model_client`` are required; every other field has a default.
    """

    # Name of the agent.
    name: str
    # Serialized component model of the chat-completion client.
    model_client: ComponentModel
    # Folder where downloads are saved.
    downloads_folder: str | None = None
    # Description of the agent.
    description: str | None = None
    # Directory where debug information (e.g. screenshots) is stored.
    debug_dir: str | None = None
    # Whether the browser runs headless.
    headless: bool = True
    # Page the browser opens first.
    start_page: str | None = "https://www.bing.com/"
    # Whether actions are animated.
    animate_actions: bool = False
    # Whether screenshots are saved.
    to_save_screenshots: bool = False
    # Whether OCR is used.
    use_ocr: bool = False
    # Browser channel.
    browser_channel: str | None = None
    # Browser data directory.
    browser_data_dir: str | None = None
    # Whether the viewport is resized.
    to_resize_viewport: bool = True
class MultimodalWebSurfer(BaseChatAgent, Component[MultimodalWebSurferConfig]):
"""
MultimodalWebSurfer is a multimodal agent that acts as a web surfer that can search the web and visit web pages.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[web-surfer]"
It launches a chromium browser and allows the playwright to interact with the web browser and can perform a variety of actions. The browser is launched on the first call to the agent and is reused for subsequent calls.
It must be used with a multimodal model client that supports function/tool calling, ideally GPT-4o currently.
When :meth:`on_messages` or :meth:`on_messages_stream` is called, the following occurs:
1) If this is the first call, the browser is initialized and the page is loaded. This is done in :meth:`_lazy_init`. The browser is only closed when :meth:`close` is called.
2) The method :meth:`_generate_reply` is called, which then creates the final response as below.
3) The agent takes a screenshot of the page, extracts the interactive elements, and prepares a set-of-mark screenshot with bounding boxes around the interactive elements.
4) The agent makes a call to the :attr:`model_client` with the SOM screenshot, history of messages, and the list of available tools.
- If the model returns a string, the agent returns the string as the final response.
- If the model returns a list of tool calls, the agent executes the tool calls with :meth:`_execute_tool` using :attr:`_playwright_controller`.
- The agent returns a final response which includes a screenshot of the page, page metadata, description of the action taken and the inner text of the webpage.
5) If at any point the agent encounters an error, it returns the error message as the final response.
.. note::
Please note that using the MultimodalWebSurfer involves interacting with a digital world designed for humans, which carries inherent risks.
Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences.
Moreover, be cautious that MultimodalWebSurfer may be susceptible to prompt injection attacks from webpages.
.. note::
On Windows, the event loop policy must be set to `WindowsProactorEventLoopPolicy` to avoid issues with subprocesses.
.. code-block:: python
import sys
import asyncio
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client used by the agent. Must be multimodal and support function calling.
downloads_folder (str, optional): The folder where downloads are saved. Defaults to None, no downloads are saved.
description (str, optional): The description of the agent. Defaults to MultimodalWebSurfer.DEFAULT_DESCRIPTION.
debug_dir (str, optional): The directory where debug information is saved. Defaults to None.
headless (bool, optional): Whether the browser should be headless. Defaults to True.
start_page (str, optional): The start page for the browser. Defaults to MultimodalWebSurfer.DEFAULT_START_PAGE.
animate_actions (bool, optional): Whether to animate actions. Defaults to False.
to_save_screenshots (bool, optional): Whether to save screenshots. Defaults to False.
use_ocr (bool, optional): Whether to use OCR. Defaults to False.
browser_channel (str, optional): The browser channel. Defaults to None.
browser_data_dir (str, optional): The browser data directory. Defaults to None.
to_resize_viewport (bool, optional): Whether to resize the viewport. Defaults to True.
playwright (Playwright, optional): The playwright instance. Defaults to None.
context (BrowserContext, optional): The browser context. Defaults to None.
Example usage:
The following example demonstrates how to create a web surfing agent with
a model client and run it for multiple turns.
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer
async def main() -> None:
# Define an agent
web_surfer_agent = MultimodalWebSurfer(
name="MultimodalWebSurfer",
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
)
# Define a team
agent_team = RoundRobinGroupChat([web_surfer_agent], max_turns=3)
# Run the team and stream messages to the console
stream = agent_team.run_stream(task="Navigate to the AutoGen readme on GitHub.")
await Console(stream)
# Close the browser controlled by the agent
await web_surfer_agent.close()
asyncio.run(main())
"""
component_type = "agent"
component_config_schema = MultimodalWebSurferConfig
component_provider_override = "agentdhal_extensions.agents.web_surfer.MultimodalWebSurfer"
DEFAULT_DESCRIPTION = """
A helpful assistant with access to a web browser.
Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.).
It can also summarize the entire page, or answer questions based on the content of the page.
It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded.
"""
DEFAULT_START_PAGE = "https://www.bing.com/"
# Viewport dimensions
VIEWPORT_HEIGHT = 900
VIEWPORT_WIDTH = 1440
# Size of the image we send to the MLM
# Current values represent a 0.85 scaling to fit within the GPT-4v short-edge constraints (768px)
MLM_HEIGHT = 765
MLM_WIDTH = 1224
SCREENSHOT_TOKENS = 1105
def __init__(
    self,
    name: str,
    model_client: ChatCompletionClient,
    downloads_folder: str | None = None,
    description: str = DEFAULT_DESCRIPTION,
    debug_dir: str | None = None,
    headless: bool = True,
    start_page: str | None = DEFAULT_START_PAGE,
    animate_actions: bool = False,
    to_save_screenshots: bool = False,
    use_ocr: bool = False,
    browser_channel: str | None = None,
    browser_data_dir: str | None = None,
    to_resize_viewport: bool = True,
    playwright: Playwright | None = None,
    context: BrowserContext | None = None,
):
    """
    Set up the agent's configuration and helpers without starting a browser.

    Raises:
        ValueError: If screenshots are requested without a debug directory, or
            if the model client reports that it does not support function calling.
    """
    super().__init__(name, description)

    # Validate the arguments before storing anything.
    if to_save_screenshots and debug_dir is None:
        raise ValueError(
            "Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist."
        )
    if model_client.model_info["function_calling"] is False:
        raise ValueError(
            "The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
        )

    # Model and browser configuration.
    self._model_client = model_client
    self.headless = headless
    self.browser_channel = browser_channel
    self.browser_data_dir = browser_data_dir
    self.start_page = start_page if start_page else self.DEFAULT_START_PAGE
    self.downloads_folder = downloads_folder
    self.debug_dir = debug_dir
    self.to_save_screenshots = to_save_screenshots
    self.use_ocr = use_ocr
    self.to_resize_viewport = to_resize_viewport
    self.animate_actions = animate_actions

    # Browser handles; a caller-supplied playwright/context is kept, the page starts unset.
    self._playwright: Playwright | None = playwright
    self._context: BrowserContext | None = context
    self._page: Page | None = None
    self._last_download: Download | None = None
    self._prior_metadata_hash: str | None = None
    self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.name}.MultimodalWebSurfer")
    self._chat_history: List[LLMMessage] = []

    # Remember the most recent download reported by the page.
    def _remember_download(download: Download) -> None:
        self._last_download = download

    self._download_handler = _remember_download

    # Controller that performs the actual browser interactions.
    self._playwright_controller = PlaywrightController(
        animate_actions=self.animate_actions,
        downloads_folder=self.downloads_folder,
        viewport_width=self.VIEWPORT_WIDTH,
        viewport_height=self.VIEWPORT_HEIGHT,
        _download_handler=self._download_handler,
        to_resize_viewport=self.to_resize_viewport,
    )

    # Tools that are always offered to the model.
    self.default_tools = [
        TOOL_VISIT_URL,
        TOOL_WEB_SEARCH,
        TOOL_HISTORY_BACK,
        TOOL_CLICK,
        TOOL_TYPE,
        TOOL_READ_PAGE_AND_ANSWER,
        TOOL_SUMMARIZE_PAGE,
        TOOL_SLEEP,
        TOOL_HOVER,
    ]

    # Flag recording whether the browser has been initialized yet.
    self.did_lazy_init = False
async def _lazy_init(
    self,
) -> None:
    """
    On the first call, we initialize the browser and the page.

    Starts Playwright and creates a browser context unless they were supplied
    to the constructor, opens a new page, installs the download handler and
    the page init script, navigates to ``self.start_page``, prepares the debug
    directory and finally sets ``self.did_lazy_init``.
    """
    # Check the current event loop policy if on windows.
    if sys.platform == "win32":
        current_policy = asyncio.get_event_loop_policy()
        if hasattr(asyncio, "WindowsProactorEventLoopPolicy") and not isinstance(
            current_policy, asyncio.WindowsProactorEventLoopPolicy
        ):
            warnings.warn(
                "The current event loop policy is not WindowsProactorEventLoopPolicy. "
                "This may cause issues with subprocesses. "
                "Try setting the event loop policy to WindowsProactorEventLoopPolicy. "
                "For example: `asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())`. "
                "See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.ProactorEventLoop.",
                stacklevel=2,
            )
    # Start from a clean per-session state.
    self._last_download = None
    self._prior_metadata_hash = None
    # Build the browser launch arguments and start Playwright if none was supplied.
    launch_args: Dict[str, Any] = {"headless": self.headless}
    if self.browser_channel is not None:
        launch_args["channel"] = self.browser_channel
    if self._playwright is None:
        self._playwright = await async_playwright().start()
    # Create the context -- are we launching persistent?
    if self._context is None:
        if self.browser_data_dir is None:
            # No data directory: launch a browser and create a new context in it.
            browser = await self._playwright.chromium.launch(**launch_args)
            self._context = await browser.new_context(
                user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0"
            )
        else:
            # Data directory given: launch a persistent context backed by it.
            self._context = await self._playwright.chromium.launch_persistent_context(
                self.browser_data_dir, **launch_args
            )
    # Create the page
    self._context.set_default_timeout(60000)  # One minute
    self._page = await self._context.new_page()
    assert self._page is not None
    # self._page.route(lambda x: True, self._route_handler)
    self._page.on("download", self._download_handler)
    if self.to_resize_viewport:
        await self._page.set_viewport_size({"width": self.VIEWPORT_WIDTH, "height": self.VIEWPORT_HEIGHT})
    # Install page_script.js (located next to this module) as a page init script.
    await self._page.add_init_script(
        path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
    )
    await self._page.goto(self.start_page)
    await self._page.wait_for_load_state()
    # Prepare the debug directory -- which stores the screenshots generated throughout the process
    await self._set_debug_dir(self.debug_dir)
    self.did_lazy_init = True
async def close(self) -> None:
    """
    Close the browser and the page.
    Should be called when the agent is no longer needed.

    The page, the context and the Playwright instance are closed in that order
    and their references cleared. The lazy-init flag is reset as well, so a
    later call initializes a fresh browser instead of failing on the missing page.
    """
    if self._page is not None:
        await self._page.close()
        self._page = None
    if self._context is not None:
        await self._context.close()
        self._context = None
    if self._playwright is not None:
        await self._playwright.stop()
        self._playwright = None
    # Everything created by _lazy_init is gone; allow it to run again.
    self.did_lazy_init = False
async def _set_debug_dir(self, debug_dir: str | None) -> None:
assert self._page is not None
if self.debug_dir is None:
return
if not os.path.isdir(self.debug_dir):
os.mkdir(self.debug_dir)
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
screenshot_png_name = "screenshot" + current_timestamp + ".png"
await self._page.screenshot(path=os.path.join(self.debug_dir, screenshot_png_name)) # type: ignore
self.logger.info(
WebSurferEvent(
source=self.name,
url=self._page.url,
message="Screenshot: " + screenshot_png_name,
)
)
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
    """Message types this agent can return as its final chat message.

    ``on_messages_stream`` returns a ``MultiModalMessage`` for multimodal
    content and a ``TextMessage`` for plain-string replies and error reports,
    so both are listed.
    """
    return (MultiModalMessage, TextMessage)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """Clear the chat history and send the browser back to the start page.

    Does nothing if the browser has not been initialized yet. The stored
    download and metadata hash are cleared when the controller says so, and a
    screenshot is saved when screenshot saving is enabled.
    """
    if not self.did_lazy_init:
        return
    assert self._page is not None

    self._chat_history.clear()
    reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
        self._page, self.start_page
    )
    if reset_last_download:
        self._last_download = None
    if reset_prior_metadata:
        self._prior_metadata_hash = None

    if self.to_save_screenshots:
        screenshot_png_name = f"screenshot_{int(time.time())}.png"
        target_path = os.path.join(self.debug_dir, screenshot_png_name)  # type: ignore
        await self._page.screenshot(path=target_path)  # type: ignore
        screenshot_event = WebSurferEvent(
            source=self.name,
            url=self._page.url,
            message="Screenshot: " + screenshot_png_name,
        )
        self.logger.info(screenshot_event)

    reset_event = WebSurferEvent(
        source=self.name,
        url=self._page.url,
        message="Resetting browser.",
    )
    self.logger.info(reset_event)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
    """Run ``on_messages_stream`` and return the first ``Response`` it produces.

    Raises:
        AssertionError: If the stream ends without yielding a ``Response``.
    """
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if not isinstance(item, Response):
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
    self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
    """Append the incoming messages to the history, generate a reply and yield it.

    Exactly one ``Response`` is yielded. A string reply is wrapped in a
    ``TextMessage``, any other content in a ``MultiModalMessage``; both carry
    the summed model usage. If generating the reply raises an ``Exception``,
    the traceback is returned as a ``TextMessage`` instead.

    The response is built inside the ``try`` block but yielded after it, so a
    ``GeneratorExit`` delivered at the yield point when the generator is
    finalized is not mistaken for a surfing error. Only ``Exception`` is
    caught, so task cancellation and interpreter-exit signals propagate.
    """
    for chat_message in messages:
        self._chat_history.append(chat_message.to_model_message())
    self.inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
    self.model_usage: List[RequestUsage] = []
    try:
        content = await self._generate_reply(cancellation_token=cancellation_token)
        self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
        final_usage = RequestUsage(
            prompt_tokens=sum(u.prompt_tokens for u in self.model_usage),
            completion_tokens=sum(u.completion_tokens for u in self.model_usage),
        )
        if isinstance(content, str):
            response = Response(
                chat_message=TextMessage(content=content, source=self.name, models_usage=final_usage),
                inner_messages=self.inner_messages,
            )
        else:
            response = Response(
                chat_message=MultiModalMessage(content=content, source=self.name, models_usage=final_usage),
                inner_messages=self.inner_messages,
            )
    except Exception:
        # Report the failure to the caller as the final message rather than raising.
        error_text = f"Web surfing error:\n\n{traceback.format_exc()}"
        self._chat_history.append(AssistantMessage(content=error_text, source=self.name))
        response = Response(chat_message=TextMessage(content=error_text, source=self.name))
    yield response
async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:
    """Produce one reply for the most recent request in the chat history.

    The model is shown a textual description of the current page (plus a
    set-of-mark screenshot when it supports vision) and the names of the
    available tools. A plain-text answer from the model is returned as-is; a
    list of tool calls is handed to ``_execute_tool`` and its result is
    returned instead.

    Args:
        cancellation_token: Forwarded to the model client and to ``_execute_tool``.

    Returns:
        The model's text answer, or whatever ``_execute_tool`` returns.

    Raises:
        AssertionError: If the model response is neither a string nor a list.
    """
    # The browser and the page are created on the first reply only.
    if not self.did_lazy_init:
        await self._lazy_init()
    assert self._page is not None

    # Clone the conversation, dropping screenshots from earlier turns.
    history: List[LLMMessage] = remove_images(self._chat_history)

    # The last message is the request being answered; it is re-appended after the page prompt.
    user_request = history.pop() if len(history) else UserMessage(content="Empty request.", source="user")

    # Only the model families listed here keep the earlier conversation;
    # every other family starts from an empty history.
    families_keeping_history = [
        ModelFamily.GPT_4O,
        ModelFamily.O1,
        ModelFamily.O3,
        ModelFamily.GPT_4,
        ModelFamily.GPT_35,
    ]
    if self._model_client.model_info["family"] not in families_keeping_history:
        history = []

    # Collect the interactive elements and build the set-of-mark screenshot.
    rects = await self._playwright_controller.get_interactive_rects(self._page)
    viewport = await self._playwright_controller.get_visual_viewport(self._page)
    screenshot = await self._page.screenshot()
    som_screenshot, visible_rects, rects_above, rects_below = add_set_of_mark(screenshot, rects)

    if self.to_save_screenshots:
        screenshot_png_name = f"screenshot_som_{int(time.time())}.png"
        som_screenshot.save(os.path.join(self.debug_dir, screenshot_png_name))  # type: ignore
        self.logger.info(
            WebSurferEvent(
                source=self.name,
                url=self._page.url,
                message="Screenshot: " + screenshot_png_name,
            )
        )

    # Start from the default tools and add the page scrolling tools that make sense right now.
    tools = self.default_tools.copy()
    can_scroll_up = viewport["pageTop"] > 5
    if can_scroll_up:
        tools.append(TOOL_SCROLL_UP)
    can_scroll_down = (viewport["pageTop"] + viewport["height"] + 5) < viewport["scrollHeight"]
    if can_scroll_down:
        tools.append(TOOL_SCROLL_DOWN)

    # Tell the model which element, if any, currently has the input focus.
    focused = await self._playwright_controller.get_focused_rect_id(self._page)
    focused_hint = ""
    if focused:
        focused_name = self._target_name(focused, rects)
        name_clause = f"(and name '{focused_name}') " if focused_name else ""
        try:
            role = rects[focused]["role"]
        except KeyError:
            role = "control"
        focused_hint = f"\nThe {role} with ID {focused} {name_clause}currently has the input focus.\n\n"

    # Targets inside the viewport.
    visible_targets = "\n".join(self._format_target_list(visible_rects, rects)) + "\n\n"

    # Targets above and below the viewport, capped at 30 entries.
    other_targets: List[str] = [
        *self._format_target_list(rects_above, rects),
        *self._format_target_list(rects_below, rects),
    ]
    other_targets_str = ""
    if len(other_targets) > 0:
        if len(other_targets) > 30:
            other_targets = other_targets[0:30]
            other_targets.append("...")
        other_targets_str = (
            "Additional valid interaction targets include (but are not limited to):\n"
            + "\n".join(other_targets)
            + "\n\n"
        )

    state_description = "Your " + await self._get_state_description()
    tool_names = "\n".join([t["name"] for t in tools])
    page_title = await self._page.title()

    # Both prompt templates take the same fields; only the template and the attachment differ.
    prompt_fields = {
        "state_description": state_description,
        "visible_targets": visible_targets,
        "other_targets_str": other_targets_str,
        "focused_hint": focused_hint,
        "tool_names": tool_names,
        "title": page_title,
        "url": self._page.url,
    }
    is_multimodal = self._model_client.model_info["vision"]
    template = WEB_SURFER_TOOL_PROMPT_MM if is_multimodal else WEB_SURFER_TOOL_PROMPT_TEXT
    # Collapse runs of three or more blank lines left behind by empty template fields.
    prompt_text = re.sub(r"(\n\s*){3,}", "\n\n", template.format(**prompt_fields).strip())

    if is_multimodal:
        # Scale the screenshot for the MLM, and close the original
        scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
        som_screenshot.close()
        if self.to_save_screenshots:
            scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png"))  # type: ignore
        prompt_message = UserMessage(
            content=[prompt_text, AGImage.from_pil(scaled_screenshot)],
            source=self.name,
        )
    else:
        prompt_message = UserMessage(content=prompt_text, source=self.name)

    # The page prompt goes first, followed by the request it introduces.
    history.extend([prompt_message, user_request])

    # Make the request
    response = await self._model_client.create(
        history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
    )
    self.model_usage.append(response.usage)
    message = response.content
    self._last_download = None

    if isinstance(message, str):
        # The model answered directly.
        self.inner_messages.append(TextMessage(content=message, source=self.name))
        return message
    if isinstance(message, list):
        # The model asked for an action.
        return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
    # Not sure what happened here
    raise AssertionError(f"Unknown response format '{message}'")
async def _execute_tool(
    self,
    message: List[FunctionCall],
    rects: Dict[str, InteractiveRegion],
    tool_names: str,
    cancellation_token: Optional[CancellationToken] = None,
) -> UserContent:
    """Run the tool call requested by the model and describe the resulting page.

    Only the first entry of ``message`` is executed.

    Args:
        message: Tool calls returned by the model.
        rects: Interactive regions of the page, keyed by target id; used to
            look up human-readable names for the action description.
        tool_names: Newline-separated tool names, quoted in the error raised
            for an unknown tool.
        cancellation_token: Forwarded to ``_summarize_page``.

    Returns:
        For ``answer_question`` and ``summarize_page``: the text returned by
        ``_summarize_page``. For every other tool: a two-item list holding a
        textual observation (action taken, page state, changed metadata) and
        a screenshot of the page.

    Raises:
        ValueError: If the tool name is unknown, or if ``visit_url`` /
            ``web_search`` is called without its string argument.
    """
    import html  # local import: only used to escape the download notice below

    name = message[0].name
    args = json.loads(message[0].arguments)
    action_description = ""
    assert self._page is not None

    call_text = f"{name}( {json.dumps(args)} )"
    self.logger.info(
        WebSurferEvent(
            source=self.name,
            url=self._page.url,
            action=name,
            arguments=args,
            message=call_text,
        )
    )
    self.inner_messages.append(TextMessage(content=call_text, source=self.name))

    if name == "visit_url":
        url = args.get("url")
        if not isinstance(url, str):
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("The 'visit_url' tool requires a string 'url' argument.")
        action_description = f"I typed '{url}' into the browser address bar."
        if url.startswith(("https://", "http://", "file://", "about:")):
            # Already has a known protocol
            destination = url
        elif " " in url:
            # Contains a space: treat it as a search query
            destination = f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH"
        else:
            # Otherwise, prefix with https://
            destination = "https://" + url
        reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
            self._page, destination
        )
        if reset_last_download:
            self._last_download = None
        if reset_prior_metadata:
            self._prior_metadata_hash = None

    elif name == "history_back":
        action_description = "I clicked the browser back button."
        await self._playwright_controller.back(self._page)

    elif name == "web_search":
        query = args.get("query")
        if not isinstance(query, str):
            # Fail with a clear message instead of a TypeError inside quote_plus.
            raise ValueError("The 'web_search' tool requires a string 'query' argument.")
        action_description = f"I typed '{query}' into the browser search bar."
        reset_prior_metadata, reset_last_download = await self._playwright_controller.visit_page(
            self._page, f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH"
        )
        if reset_last_download:
            self._last_download = None
        if reset_prior_metadata:
            self._prior_metadata_hash = None

    elif name == "scroll_up":
        action_description = "I scrolled up one page in the browser."
        await self._playwright_controller.page_up(self._page)

    elif name == "scroll_down":
        action_description = "I scrolled down one page in the browser."
        await self._playwright_controller.page_down(self._page)

    elif name == "click":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I clicked '{target_name}'."
        else:
            action_description = "I clicked the control."
        new_page_tentative = await self._playwright_controller.click_id(self._page, target_id)
        if new_page_tentative is not None:
            # The click produced a new page object: continue on it and forget the old page's metadata.
            self._page = new_page_tentative
            self._prior_metadata_hash = None
            self.logger.info(
                WebSurferEvent(
                    source=self.name,
                    url=self._page.url,
                    message="New tab or window.",
                )
            )

    elif name == "input_text":
        input_field_id = str(args.get("input_field_id"))
        text_value = str(args.get("text_value"))
        input_field_name = self._target_name(input_field_id, rects)
        if input_field_name:
            action_description = f"I typed '{text_value}' into '{input_field_name}'."
        else:
            action_description = f"I input '{text_value}'."
        await self._playwright_controller.fill_id(self._page, input_field_id, text_value)

    elif name == "scroll_element_up":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I scrolled '{target_name}' up."
        else:
            action_description = "I scrolled the control up."
        await self._playwright_controller.scroll_id(self._page, target_id, "up")

    elif name == "scroll_element_down":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I scrolled '{target_name}' down."
        else:
            action_description = "I scrolled the control down."
        await self._playwright_controller.scroll_id(self._page, target_id, "down")

    elif name == "answer_question":
        # Do Q&A on the DOM. No need to take further action. Browser state does not change.
        question = str(args.get("question"))
        return await self._summarize_page(question=question, cancellation_token=cancellation_token)

    elif name == "summarize_page":
        # Summarize the DOM. No need to take further action. Browser state does not change.
        return await self._summarize_page(cancellation_token=cancellation_token)

    elif name == "hover":
        target_id = str(args.get("target_id"))
        target_name = self._target_name(target_id, rects)
        if target_name:
            action_description = f"I hovered over '{target_name}'."
        else:
            action_description = "I hovered over the control."
        await self._playwright_controller.hover_id(self._page, target_id)

    elif name == "sleep":
        action_description = "I am waiting a short period of time before taking further action."
        await self._playwright_controller.sleep(self._page, 3)

    else:
        raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")

    # Give the page time to settle after the action.
    await self._page.wait_for_load_state()
    await self._playwright_controller.sleep(self._page, 3)

    # Handle downloads
    if self._last_download is not None and self.downloads_folder is not None:
        suggested_name = self._last_download.suggested_filename
        # The suggested name comes from the remote site: keep only its final path
        # component so the file cannot be written outside the downloads folder.
        fname = os.path.join(self.downloads_folder, os.path.basename(suggested_name))
        await self._last_download.save_as(fname)  # type: ignore
        # Escape both values before embedding them in the notice page shown in the browser.
        page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{html.escape(suggested_name)}' to local path:<br><br>{html.escape(fname)}</h1></body></html>"
        await self._page.goto(
            "data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
        )
        await self._page.wait_for_load_state()

    # Handle metadata: only report it when it differs from what was reported last time.
    page_metadata = json.dumps(await self._playwright_controller.get_page_metadata(self._page), indent=4)
    metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest()
    if metadata_hash != self._prior_metadata_hash:
        page_metadata = (
            "\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
        )
    else:
        page_metadata = ""
    self._prior_metadata_hash = metadata_hash

    new_screenshot = await self._page.screenshot()
    if self.to_save_screenshots:
        current_timestamp = "_" + int(time.time()).__str__()
        screenshot_png_name = "screenshot" + current_timestamp + ".png"
        async with aiofiles.open(os.path.join(self.debug_dir, screenshot_png_name), "wb") as file:  # type: ignore
            await file.write(new_screenshot)  # type: ignore
        self.logger.info(
            WebSurferEvent(
                source=self.name,
                url=self._page.url,
                message="Screenshot: " + screenshot_png_name,
            )
        )

    # Return the complete observation
    state_description = "The " + await self._get_state_description()
    message_content = (
        f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page."
    )
    return [
        re.sub(r"(\n\s*){3,}", "\n\n", message_content),  # Removing blank lines
        AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))),
    ]
async def _get_state_description(self) -> str:
assert self._playwright_controller is not None
assert self._page is not None
# Describe the viewport of the new page in words
viewport = await self._playwright_controller.get_visual_viewport(self._page)
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
if percent_scrolled < 1: # Allow some rounding error
position_text = "at the top of the page"
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
position_text = "at the bottom of the page"
else:
position_text = str(percent_scrolled) + "% down from the top of the page"
visible_text = await self._playwright_controller.get_visible_text(self._page)
# Return the complete observation
page_title = await self._page.title()
message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n"
message_content += f"The following text is visible in the viewport:\n\n{visible_text}"
return message_content
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
try:
return rects[target]["aria_name"].strip()
except KeyError:
return None
def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion]) -> List[str]:
"""
Format the list of targets in the webpage as a string to be used in the agent's prompt.
"""
targets: List[str] = []
for r in list(set(ids)):
if r in rects:
# Get the role
aria_role = rects[r].get("role", "").strip()
if len(aria_role) == 0:
aria_role = rects[r].get("tag_name", "").strip()
# Get the name
aria_name = re.sub(r"[\n\r]+", " ", rects[r].get("aria_name", "")).strip()
# What are the actions?
actions = ['"click", "hover"']
if rects[r]["role"] in ["textbox", "searchbox", "search"]:
actions = ['"input_text"']
actions_str = "[" + ",".join(actions) + "]"
targets.append(f'{{"id": {r}, "name": "{aria_name}", "role": "{aria_role}", "tools": {actions_str} }}')
return targets
async def _summarize_page(
    self,
    question: str | None = None,
    cancellation_token: Optional[CancellationToken] = None,
) -> str:
    """Ask the model to summarize the current page, optionally with respect to a question.

    The page's markdown is added to the prompt line by line until the model's
    remaining token budget runs out. Models with vision also receive a scaled
    screenshot of the viewport, for which ``SCREENSHOT_TOKENS`` are held back.

    Args:
        question: Optional question the summary should focus on.
        cancellation_token: Forwarded to the model client.

    Returns:
        The model's answer, or "Nothing to summarize." when no page text fits
        or the page has no text.
    """
    assert self._page is not None
    page_markdown: str = await self._playwright_controller.get_page_markdown(self._page)

    # Prefer the page title; fall back to the URL when the title cannot be read.
    title: str = self._page.url
    try:
        title = await self._page.title()
    except Exception:
        pass

    is_multimodal = self._model_client.model_info["vision"]

    # Take a screenshot and scale it
    screenshot = Image.open(io.BytesIO(await self._page.screenshot()))
    scaled_screenshot = screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
    screenshot.close()

    # The finally clause releases the scaled image on every exit path,
    # including the early return and a failing model call.
    try:
        ag_image = AGImage.from_pil(scaled_screenshot)

        # Prepare the system prompt
        messages: List[LLMMessage] = []
        messages.append(SystemMessage(content=WEB_SURFER_QA_SYSTEM_MESSAGE))
        prompt = WEB_SURFER_QA_PROMPT(title, question)

        # Grow the buffer (which is added to the prompt) until we overflow the
        # context window or run out of lines. Line endings are kept so the
        # page text is not run together into one line.
        buffer = ""
        for line in page_markdown.splitlines(keepends=True):
            trial_message = UserMessage(
                content=prompt + buffer + line,
                source=self.name,
            )
            try:
                remaining = self._model_client.remaining_tokens(messages + [trial_message])
            except KeyError:
                # Use the default if the model isn't found
                remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message])

            # Stop once the context is full; this applies to text-only models too.
            if remaining <= 0:
                break
            # Models with vision additionally need room for the screenshot.
            if is_multimodal and remaining <= self.SCREENSHOT_TOKENS:
                break
            buffer += line

        # Nothing to do
        buffer = buffer.strip()
        if len(buffer) == 0:
            return "Nothing to summarize."

        # Append the message
        if is_multimodal:
            # Multimodal
            messages.append(
                UserMessage(
                    content=[
                        prompt + buffer,
                        ag_image,
                    ],
                    source=self.name,
                )
            )
        else:
            # Text only
            messages.append(
                UserMessage(
                    content=prompt + buffer,
                    source=self.name,
                )
            )

        # Generate the response
        response = await self._model_client.create(messages, cancellation_token=cancellation_token)
        self.model_usage.append(response.usage)
        assert isinstance(response.content, str)
        return response.content
    finally:
        scaled_screenshot.close()
def _to_config(self) -> MultimodalWebSurferConfig:
    """Capture this agent's settings, including the dumped model client, as a config object."""
    serialized_client = self._model_client.dump_component()
    return MultimodalWebSurferConfig(
        name=self.name,
        model_client=serialized_client,
        # Everything below is copied straight from the attribute of the same name.
        downloads_folder=self.downloads_folder,
        description=self.description,
        debug_dir=self.debug_dir,
        headless=self.headless,
        start_page=self.start_page,
        animate_actions=self.animate_actions,
        to_save_screenshots=self.to_save_screenshots,
        use_ocr=self.use_ocr,
        browser_channel=self.browser_channel,
        browser_data_dir=self.browser_data_dir,
        to_resize_viewport=self.to_resize_viewport,
    )
@classmethod
def _from_config(cls, config: MultimodalWebSurferConfig) -> Self:
    """Build an agent from a config object.

    The model client is loaded from its dumped form. A falsy description or
    start page falls back to the class defaults.
    """
    client = ChatCompletionClient.load_component(config.model_client)
    description = config.description or cls.DEFAULT_DESCRIPTION
    start_page = config.start_page or cls.DEFAULT_START_PAGE
    return cls(
        name=config.name,
        model_client=client,
        downloads_folder=config.downloads_folder,
        description=description,
        debug_dir=config.debug_dir,
        headless=config.headless,
        start_page=start_page,
        animate_actions=config.animate_actions,
        to_save_screenshots=config.to_save_screenshots,
        use_ocr=config.use_ocr,
        browser_channel=config.browser_channel,
        browser_data_dir=config.browser_data_dir,
        to_resize_viewport=config.to_resize_viewport,
    )

View File

@@ -0,0 +1,52 @@
# Tool-selection prompt for models with vision. It is filled in with str.format
# using the fields state_description, visible_targets, other_targets_str,
# focused_hint, tool_names, title and url, and is sent together with a
# screenshot in which the interactive elements are marked.
WEB_SURFER_TOOL_PROMPT_MM = """
{state_description}
Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
My request follows:
"""
# Text-only variant of the tool-selection prompt. It takes the same format
# fields as WEB_SURFER_TOOL_PROMPT_MM but does not refer to a screenshot.
WEB_SURFER_TOOL_PROMPT_TEXT = """
{state_description}
You have also identified the following interactive components:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
My request follows:
"""
# System message for the page summarization / question-answering request.
WEB_SURFER_QA_SYSTEM_MESSAGE = """
You are a helpful assistant that can summarize long documents to answer question.
"""
def WEB_SURFER_QA_PROMPT(title: str, question: str | None = None) -> str:
base_prompt = f"We are visiting the webpage '{title}'. Its full-text content are pasted below, along with a screenshot of the page's current viewport."
if question is not None:
return (
f"{base_prompt} Please summarize the webpage into one or two paragraphs with respect to '{question}':\n\n"
)
else:
return f"{base_prompt} Please summarize the webpage into one or two paragraphs:\n\n"

View File

@@ -0,0 +1,96 @@
import io
import random
from typing import BinaryIO, Dict, List, Tuple, cast
from PIL import Image, ImageDraw, ImageFont
from ._types import DOMRectangle, InteractiveRegion
# Height in pixels of the strip at the top of the screenshot that is kept free of labels:
# a rectangle whose top edge lies within it gets its label anchored at its bottom edge instead.
TOP_NO_LABEL_ZONE = 20 # Don't print any labels close to the top of the page
def add_set_of_mark(
    screenshot: bytes | Image.Image | io.BufferedIOBase, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
    """Annotate a screenshot with set-of-mark boxes for the given interactive regions.

    Args:
        screenshot: A PIL image, raw image bytes, or a binary file-like object.
        ROIs: Interactive regions keyed by id.

    Returns:
        The annotated image, followed by the ids of the regions that are
        visible, above the viewport, and below the viewport.
    """
    # An image that is already open is annotated directly and left open.
    if isinstance(screenshot, Image.Image):
        return _add_set_of_mark(screenshot, ROIs)

    # Raw bytes are wrapped so that both remaining input kinds are file-like.
    stream = io.BytesIO(screenshot) if isinstance(screenshot, bytes) else screenshot

    # TODO: Not sure why this cast was needed, but by this point stream is a binary file-like object
    image = Image.open(cast(BinaryIO, stream))
    marked = _add_set_of_mark(image, ROIs)
    image.close()
    return marked
def _add_set_of_mark(
    screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion]
) -> Tuple[Image.Image, List[str], List[str], List[str]]:
    """Draw labelled boxes over a greyscale copy of ``screenshot`` and classify each rectangle.

    A rectangle is classified by its midpoint: it is ignored when the midpoint
    falls outside the image horizontally, counted as "above" or "below" when
    it falls outside vertically, and otherwise counted as visible and drawn.
    An id is appended once per rectangle, so it can appear more than once.

    Returns:
        The composited image and the id lists for visible, above and below.
    """
    visible_rects: List[str] = []
    rects_above: List[str] = []  # Scroll up to see
    rects_below: List[str] = []  # Scroll down to see

    fnt = ImageFont.load_default(14)
    base = screenshot.convert("L").convert("RGBA")
    overlay = Image.new("RGBA", base.size)
    draw = ImageDraw.Draw(overlay)

    for roi_id, region in ROIs.items():
        for rect in region["rects"]:
            # Empty rectangles
            if not rect:
                continue
            if rect["width"] * rect["height"] == 0:
                continue

            mid_x = (rect["right"] + rect["left"]) / 2.0
            mid_y = (rect["top"] + rect["bottom"]) / 2.0
            if not (0 <= mid_x and mid_x < base.size[0]):
                continue

            if mid_y < 0:
                rects_above.append(roi_id)
            elif mid_y >= base.size[1]:
                rects_below.append(roi_id)
            else:
                visible_rects.append(roi_id)
                _draw_roi(draw, int(roi_id), fnt, rect)

    comp = Image.alpha_composite(base, overlay)
    overlay.close()
    return comp, visible_rects, rects_above, rects_below
def _draw_roi(
    draw: ImageDraw.ImageDraw, idx: int, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, rect: DOMRectangle
) -> None:
    """Draw one tinted rectangle with its numeric label.

    Args:
        draw: Drawing context to paint on.
        idx: Numeric id; used as the label text and as the seed for the colour.
        font: Font for the label.
        rect: Rectangle to outline.
    """
    color = _color(idx)
    red, green, blue = color[0], color[1], color[2]

    # Pick black or white label text depending on how bright the box colour is.
    luminance = red * 0.3 + green * 0.59 + blue * 0.11
    text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)

    label = str(idx)
    roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))

    # The label normally hangs off the top-right corner; near the top of the
    # image it is attached to the bottom-right corner instead.
    if rect["top"] <= TOP_NO_LABEL_ZONE:
        label_location = (rect["right"], rect["bottom"])
        label_anchor = "rt"
    else:
        label_location = (rect["right"], rect["top"])
        label_anchor = "rb"

    draw.rectangle(roi, outline=color, fill=(red, green, blue, 48), width=2)

    # TODO: Having trouble with these types being partially Unknown.
    bbox = draw.textbbox(label_location, label, font=font, anchor=label_anchor, align="center")  # type: ignore
    padded_bbox = (bbox[0] - 3, bbox[1] - 3, bbox[2] + 3, bbox[3] + 3)
    draw.rectangle(padded_bbox, fill=color)

    # TODO: Having trouble with these types being partially Unknown.
    draw.text(label_location, label, fill=text_color, font=font, anchor=label_anchor, align="center")  # type: ignore
def _color(identifier: int) -> Tuple[int, int, int, int]:
rnd = random.Random(int(identifier))
color = [rnd.randint(0, 255), rnd.randint(125, 255), rnd.randint(0, 50)]
rnd.shuffle(color)
color.append(255)
return cast(Tuple[int, int, int, int], tuple(color))

View File

@@ -0,0 +1,317 @@
from typing import Any, Dict
from agentdhal_core.tools._base import ParametersSchema, ToolSchema
def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema:
    """Convert a function-style tool definition dict into a ``ToolSchema``.

    Reads the name, description, and the parameter properties and required
    list from ``tooldef["function"]``; the parameter type is always "object".
    """
    function_spec = tooldef["function"]
    return ToolSchema(
        name=function_spec["name"],
        description=function_spec["description"],
        parameters=ParametersSchema(
            type="object",
            properties=function_spec["parameters"]["properties"],
            required=function_spec["parameters"]["required"],
        ),
    )
# Description shared by the "reasoning" parameter that every tool takes.
REASONING_TOOL_PROMPT = (
    "A short description of the action to be performed and reason for doing so, do not mention the user."
)


def _define_tool(name: str, description: str, /, **extra_properties: Dict[str, Any]) -> ToolSchema:
    """Build the schema for one tool.

    Every tool takes a required "reasoning" string; ``extra_properties`` adds
    further parameters as ``parameter_name=json_schema_dict``. All parameters
    are marked required, in the order given.
    """
    properties: Dict[str, Any] = {
        "reasoning": {"type": "string", "description": REASONING_TOOL_PROMPT},
        **extra_properties,
    }
    return _load_tool(
        {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": list(properties),
                },
            },
        }
    )


# --- Navigation ---
TOOL_VISIT_URL: ToolSchema = _define_tool(
    "visit_url",
    "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).",
    url={"type": "string", "description": "The URL to visit in the browser."},
)

TOOL_WEB_SEARCH: ToolSchema = _define_tool(
    "web_search",
    "Performs a web search on Bing.com with the given query.",
    query={"type": "string", "description": "The web search query to use."},
)

TOOL_HISTORY_BACK: ToolSchema = _define_tool(
    "history_back",
    "Navigates back one page in the browser's history. This is equivalent to clicking the browser back button.",
)

# --- Page scrolling ---
TOOL_SCROLL_UP: ToolSchema = _define_tool(
    "scroll_up",
    "Scrolls the entire browser viewport one page UP towards the beginning.",
)

TOOL_SCROLL_DOWN: ToolSchema = _define_tool(
    "scroll_down",
    "Scrolls the entire browser viewport one page DOWN towards the end.",
)

# --- Element interaction ---
TOOL_CLICK: ToolSchema = _define_tool(
    "click",
    "Clicks the mouse on the target with the given id.",
    target_id={"type": "integer", "description": "The numeric id of the target to click."},
)

TOOL_TYPE: ToolSchema = _define_tool(
    "input_text",
    "Types the given text value into the specified field.",
    input_field_id={
        "type": "integer",
        "description": "The numeric id of the input field to receive the text.",
    },
    text_value={"type": "string", "description": "The text to type into the input field."},
)

TOOL_SCROLL_ELEMENT_DOWN: ToolSchema = _define_tool(
    "scroll_element_down",
    "Scrolls a given html element (e.g., a div or a menu) DOWN.",
    target_id={"type": "integer", "description": "The numeric id of the target to scroll down."},
)

TOOL_SCROLL_ELEMENT_UP: ToolSchema = _define_tool(
    "scroll_element_up",
    "Scrolls a given html element (e.g., a div or a menu) UP.",
    target_id={"type": "integer", "description": "The numeric id of the target to scroll UP."},
)

TOOL_HOVER: ToolSchema = _define_tool(
    "hover",
    "Hovers the mouse over the target with the given id.",
    target_id={"type": "integer", "description": "The numeric id of the target to hover over."},
)

# --- Reading the page ---
TOOL_READ_PAGE_AND_ANSWER: ToolSchema = _define_tool(
    "answer_question",
    "Uses AI to answer a question about the current webpage's content.",
    question={"type": "string", "description": "The question to answer."},
)

TOOL_SUMMARIZE_PAGE: ToolSchema = _define_tool(
    "summarize_page",
    "Uses AI to summarize the entire page.",
)

# --- Waiting ---
TOOL_SLEEP: ToolSchema = _define_tool(
    "sleep",
    "Wait a short period of time. Call this function if the page has not yet fully loaded, or if it is determined that a small delay would increase the task's chances of success.",
)

View File

@@ -0,0 +1,106 @@
from typing import Any, Dict, List, TypedDict, Union
from agentdhal_core import FunctionCall, Image
from agentdhal_core.models import FunctionExecutionResult
# Content type aliases, one per message role.
# User content: plain text, or a list mixing text and images.
UserContent = Union[str, List[Union[str, Image]]]
# Assistant content: plain text, or a list of function calls.
AssistantContent = Union[str, List[FunctionCall]]
# Function execution content: a list of function execution results.
FunctionExecutionContent = List[FunctionExecutionResult]
# System content: plain text only.
SystemContent = str
class DOMRectangle(TypedDict):
    """A rectangle given by its origin, its size and its four edges; all values are numbers."""

    # Origin
    x: Union[int, float]
    y: Union[int, float]
    # Size
    width: Union[int, float]
    height: Union[int, float]
    # Edges
    top: Union[int, float]
    right: Union[int, float]
    bottom: Union[int, float]
    left: Union[int, float]
class VisualViewport(TypedDict):
    """Numeric measurements of the viewport and of the page it is showing."""

    # Viewport size
    height: Union[int, float]
    width: Union[int, float]
    # Viewport offsets
    offsetLeft: Union[int, float]
    offsetTop: Union[int, float]
    # Viewport position within the page
    pageLeft: Union[int, float]
    pageTop: Union[int, float]
    scale: Union[int, float]
    # Client and scroll dimensions
    clientWidth: Union[int, float]
    clientHeight: Union[int, float]
    scrollWidth: Union[int, float]
    scrollHeight: Union[int, float]
class InteractiveRegion(TypedDict):
    """An interactive page element: its tag, role, name, scrollability and rectangles."""

    tag_name: str
    role: str
    aria_name: str
    v_scrollable: bool
    # An element can be made up of several rectangles.
    rects: List[DOMRectangle]
# Helper functions for dealing with JSON. Not sure there's a better way?
def _get_str(d: Any, k: str) -> str:
val = d[k]
assert isinstance(val, str)
return val
def _get_number(d: Any, k: str) -> Union[int, float]:
val = d[k]
assert isinstance(val, int) or isinstance(val, float)
return val
def _get_bool(d: Any, k: str) -> bool:
val = d[k]
assert isinstance(val, bool)
return val
def domrectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle:
    """Build a ``DOMRectangle`` from a plain dict, checking that every field is a number."""

    def number(key: str) -> Union[int, float]:
        return _get_number(rect, key)

    return DOMRectangle(
        x=number("x"),
        y=number("y"),
        width=number("width"),
        height=number("height"),
        top=number("top"),
        right=number("right"),
        bottom=number("bottom"),
        left=number("left"),
    )
def interactiveregion_from_dict(region: Dict[str, Any]) -> InteractiveRegion:
    """Build an ``InteractiveRegion`` from a plain dict.

    The hyphenated input keys "aria-name" and "v-scrollable" are stored under
    the underscore names ``aria_name`` and ``v_scrollable``.
    """
    typed_rects: List[DOMRectangle] = [domrectangle_from_dict(raw_rect) for raw_rect in region["rects"]]
    return InteractiveRegion(
        tag_name=_get_str(region, "tag_name"),
        role=_get_str(region, "role"),
        aria_name=_get_str(region, "aria-name"),
        v_scrollable=_get_bool(region, "v-scrollable"),
        rects=typed_rects,
    )
def visualviewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport:
    """Build a typed ``VisualViewport`` from a raw dictionary.

    Args:
        viewport: Raw dictionary holding the eleven viewport fields.

    Returns:
        A ``VisualViewport`` whose values were each validated by ``_get_number``.
    """
    typed_viewport: VisualViewport = {
        "height": _get_number(viewport, "height"),
        "width": _get_number(viewport, "width"),
        "offsetLeft": _get_number(viewport, "offsetLeft"),
        "offsetTop": _get_number(viewport, "offsetTop"),
        "pageLeft": _get_number(viewport, "pageLeft"),
        "pageTop": _get_number(viewport, "pageTop"),
        "scale": _get_number(viewport, "scale"),
        "clientWidth": _get_number(viewport, "clientWidth"),
        "clientHeight": _get_number(viewport, "clientHeight"),
        "scrollWidth": _get_number(viewport, "scrollWidth"),
        "scrollHeight": _get_number(viewport, "scrollHeight"),
    }
    return typed_viewport

View File

@@ -0,0 +1,429 @@
var MultimodalWebSurfer = MultimodalWebSurfer || (function() {
// Next numeric value handed out as an "__elementId" attribute; labelElements
// increments it each time it labels an element.
let nextLabel = 10;
// Fallback role lookup used by getApproximateAriaRole when an element has no
// explicit "role" attribute. Keys are the lower-cased tag name, or
// "input, type=<type>" for <input> elements that carry a type attribute.
let roleMapping = {
"a": "link",
"area": "link",
"button": "button",
"input, type=button": "button",
"input, type=checkbox": "checkbox",
"input, type=email": "textbox",
"input, type=number": "spinbutton",
"input, type=radio": "radio",
"input, type=range": "slider",
"input, type=reset": "button",
"input, type=search": "searchbox",
"input, type=submit": "button",
"input, type=tel": "textbox",
"input, type=text": "textbox",
"input, type=url": "textbox",
"search": "search",
"select": "combobox",
"option": "option",
"textarea": "textbox"
};
let getCursor = function(elm) {
    // Computed value of the CSS `cursor` property for `elm`.
    const computedStyle = window.getComputedStyle(elm);
    return computedStyle.cursor;
};
let getInteractiveElements = function() {
    // Returns a de-duplicated array of elements that look interactive, in
    // discovery order:
    //   1. natively interactive or focusable elements,
    //   2. elements whose `role` attribute names an interactive role,
    //   3. elements whose computed cursor suggests interactivity.
    let results = [];
    // Mirror of `results` for O(1) membership tests; the array keeps the order.
    let seen = new Set();
    let addUnique = function(node) {
        if (!seen.has(node)) {
            seen.add(node);
            results.push(node);
        }
    };

    let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"];
    let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"];

    // Get the main interactive elements
    let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])");
    for (let i = 0; i < nodeList.length; i++) {
        addUnique(nodeList[i]);
    }

    // Anything not already included that has a suitable role
    nodeList = document.querySelectorAll("[role]");
    for (let i = 0; i < nodeList.length; i++) {
        let role = nodeList[i].getAttribute("role");
        if (roles.indexOf(role) > -1) {
            addUnique(nodeList[i]);
        }
    }

    // Any element that changes the cursor to something implying interactivity
    nodeList = document.querySelectorAll("*");
    for (let i = 0; i < nodeList.length; i++) {
        let node = nodeList[i];

        // Cursor is default, or does not suggest interactivity
        let cursor = getCursor(node);
        if (inertCursors.indexOf(cursor) >= 0) {
            continue;
        }

        // Move up to the first instance of this cursor change.
        // `parent` is declared locally so it no longer overwrites the global
        // `window.parent`, and the climb uses parentElement so it stops at the
        // root element instead of passing `document` (not an Element) to
        // getComputedStyle, which throws.
        let parent = node.parentElement;
        while (parent && getCursor(parent) == cursor) {
            node = parent;
            parent = node.parentElement;
        }

        // Add the node if it is new
        addUnique(node);
    }
    return results;
};
let labelElements = function(elements) {
    // Give every element that lacks one a unique "__elementId" attribute,
    // drawing the value from the shared `nextLabel` counter.
    for (let index = 0; index < elements.length; index++) {
        const elm = elements[index];
        if (elm.hasAttribute("__elementId")) {
            continue;
        }
        elm.setAttribute("__elementId", String(nextLabel++));
    }
};
let isTopmost = function(element, x, y) {
    // True when the element rendered at viewport point (x, y) is `element`
    // itself or one of its descendants.
    const hitElement = document.elementFromPoint(x, y);

    // Hack to handle elements outside the viewport: no hit counts as topmost.
    if (hitElement === null) {
        return true;
    }

    // Walk from the hit element towards the root looking for `element`.
    for (let node = hitElement; node; node = node.parentNode) {
        if (node == element) {
            return true;
        }
    }
    return false;
};
let getFocusedElementId = function() {
    // Starting at the focused element, return the "__elementId" of the nearest
    // labelled node on the way to the root, or null when none is labelled.
    for (let node = document.activeElement; node; node = node.parentNode) {
        // Non-element nodes such as `document` have no hasAttribute method.
        if (node.hasAttribute && node.hasAttribute("__elementId")) {
            return node.getAttribute("__elementId");
        }
    }
    return null;
};
let trimmedInnerText = function(element) {
    // innerText of `element` with surrounding whitespace removed; "" when the
    // element is missing or has no text.
    const rawText = element ? element.innerText : "";
    return rawText ? rawText.trim() : "";
};
let getApproximateAriaName = function(element) {
    // Best-effort accessible name for `element`. Sources are tried in order:
    // aria-labelledby, aria-label, <label for=id>, an enclosing <label>,
    // alt, title, and finally the element's own trimmed inner text.

    // Check for aria labels
    if (element.hasAttribute("aria-labelledby")) {
        let buffer = "";
        // The attribute is a whitespace-separated ID list; splitting on runs of
        // whitespace copes with tabs, newlines and repeated spaces.
        let ids = element.getAttribute("aria-labelledby").split(/\s+/);
        for (let i = 0; i < ids.length; i++) {
            if (!ids[i]) {
                continue;
            }
            let label = document.getElementById(ids[i]);
            if (label) {
                buffer = buffer + " " + trimmedInnerText(label);
            }
        }
        buffer = buffer.trim();
        // Only stop here when the references produced text; dangling or empty
        // references fall through to the remaining sources.
        if (buffer != "") {
            return buffer;
        }
    }

    if (element.hasAttribute("aria-label")) {
        let ariaLabel = element.getAttribute("aria-label");
        // A blank aria-label does not name anything; keep looking.
        if (ariaLabel.trim() != "") {
            return ariaLabel;
        }
    }

    // Check for labels
    if (element.hasAttribute("id")) {
        let label_id = element.getAttribute("id");
        let label = "";
        // CSS.escape keeps ids containing quotes, backslashes or other special
        // characters from producing an invalid selector, which would throw.
        let labels = document.querySelectorAll("label[for='" + CSS.escape(label_id) + "']");
        for (let j = 0; j < labels.length; j++) {
            label += labels[j].innerText + " ";
        }
        label = label.trim();
        if (label != "") {
            return label;
        }
    }

    if (element.parentElement && element.parentElement.tagName == "LABEL") {
        return element.parentElement.innerText;
    }

    // Check for alt text or titles
    if (element.hasAttribute("alt")) {
        return element.getAttribute("alt");
    }

    if (element.hasAttribute("title")) {
        return element.getAttribute("title");
    }

    return trimmedInnerText(element);
};
let getApproximateAriaRole = function(element) {
    // Returns [role, tag]. `tag` is the lower-cased tag name, extended to
    // "input, type=<type attribute>" for inputs that carry a type attribute.
    // `role` is the explicit role attribute when present, otherwise the
    // roleMapping entry for the element, otherwise "".
    let tag = element.tagName.toLowerCase();
    let lookupKey = tag;
    if (tag == "input") {
        let typeAttr = element.getAttribute("type");
        if (element.hasAttribute("type")) {
            // The reported tag keeps the attribute exactly as written.
            tag = tag + ", type=" + typeAttr;
        }
        // For the role lookup, HTML treats `type` case-insensitively and a
        // missing or empty type as "text".
        lookupKey = "input, type=" + (typeAttr || "text").toLowerCase();
    }

    if (element.hasAttribute("role")) {
        return [element.getAttribute("role"), tag];
    }
    // Own-property check: the `in` operator would also match names inherited
    // from Object.prototype (e.g. an unknown <constructor> element).
    if (Object.prototype.hasOwnProperty.call(roleMapping, lookupKey)) {
        return [roleMapping[lookupKey], tag];
    }
    return ["", tag];
};
let getInteractiveRects = function() {
    // Labels every interactive element, then returns an object keyed by
    // "__elementId". Each value describes the element (tag, role, name,
    // vertical scrollability) plus those of its client rectangles whose centre
    // point is not covered by another element. Elements with no such rectangle
    // are omitted.
    labelElements(getInteractiveElements());

    const regions = {};
    const labelled = document.querySelectorAll("[__elementId]");
    for (let idx = 0; idx < labelled.length; idx++) {
        const elm = labelled[idx];
        const elementId = elm.getAttribute("__elementId");
        const clientRects = elm.getClientRects();
        const roleAndTag = getApproximateAriaRole(elm);
        const name = getApproximateAriaName(elm);
        // Scrollable when the content is at least one pixel taller than the box.
        const scrollableVertically = elm.scrollHeight - elm.clientHeight >= 1;

        const uncoveredRects = [];
        for (const rect of clientRects) {
            const centerX = rect.left + rect.width/2;
            const centerY = rect.top + rect.height/2;
            if (isTopmost(elm, centerX, centerY)) {
                // Round-trip through JSON to turn the DOMRect into a plain object.
                uncoveredRects.push(JSON.parse(JSON.stringify(rect)));
            }
        }

        if (uncoveredRects.length > 0) {
            regions[elementId] = {
                "tag_name": roleAndTag[1],
                "role": roleAndTag[0],
                "aria-name": name,
                "v-scrollable": scrollableVertically,
                "rects": uncoveredRects
            };
        }
    }
    return regions;
};
let getVisualViewport = function() {
    // Snapshot of window.visualViewport plus the root element's client/scroll
    // sizes. Any value whose source object is unavailable is reported as 0.
    const viewport = window.visualViewport;
    const root = document.documentElement;
    const read = function(source, property) {
        return source ? source[property] : 0;
    };
    return {
        "height": read(viewport, "height"),
        "width": read(viewport, "width"),
        "offsetLeft": read(viewport, "offsetLeft"),
        "offsetTop": read(viewport, "offsetTop"),
        "pageLeft": read(viewport, "pageLeft"),
        "pageTop": read(viewport, "pageTop"),
        "scale": read(viewport, "scale"),
        "clientWidth": read(root, "clientWidth"),
        "clientHeight": read(root, "clientHeight"),
        "scrollWidth": read(root, "scrollWidth"),
        "scrollHeight": read(root, "scrollHeight")
    };
};
let _getMetaTags = function() {
    // Maps each <meta> tag's `name` (or, failing that, `property`) to its
    // `content`. Tags without either key attribute or without `content` are
    // skipped; later duplicates overwrite earlier ones.
    const tags = {};
    const metaNodes = document.querySelectorAll("meta");
    for (let idx = 0; idx < metaNodes.length; idx++) {
        const node = metaNodes[idx];
        let key;
        if (node.hasAttribute("name")) {
            key = node.getAttribute("name");
        } else if (node.hasAttribute("property")) {
            key = node.getAttribute("property");
        } else {
            continue;
        }
        if (node.hasAttribute("content")) {
            tags[key] = node.getAttribute("content");
        }
    }
    return tags;
};
let _getJsonLd = function() {
    // Trimmed raw text of every <script type="application/ld+json"> element,
    // in document order.
    const scripts = document.querySelectorAll('script[type="application/ld+json"]');
    return Array.from(scripts, function(script) {
        return script.innerHTML.trim();
    });
};
// From: https://www.stevefenton.co.uk/blog/2022/12/parse-microdata-with-javascript/
let _getMicrodata = function() {
    // Returns one object per [itemscope] element in the document. Each object
    // carries `itemType` plus one entry per itemprop found beneath the element;
    // a property that occurs several times becomes a flat array of its values.

    // Collapse every whitespace character to a space and trim the ends.
    function sanitize(input) {
        return input.replace(/\s/gi, ' ').trim();
    }

    // Store `value` under `name`, promoting to (or appending to) an array when
    // the name already has a value.
    function addValue(information, name, value) {
        if (information[name]) {
            // Array.isArray is required here: `typeof` reports "object" for
            // arrays, so the previous `typeof ... === 'array'` test never
            // matched and a third value produced nested arrays.
            if (Array.isArray(information[name])) {
                information[name].push(value);
            } else {
                information[name] = [information[name], value];
            }
        } else {
            information[name] = value;
        }
    }

    // Depth-first walk over the children of `item`, recording itemprops into
    // `information`. A child with both itemscope and itemprop becomes a nested
    // object; a child with itemscope but no itemprop is not descended into.
    function traverseItem(item, information) {
        const children = item.children;

        for (let i = 0; i < children.length; i++) {
            const child = children[i];

            if (child.hasAttribute('itemscope')) {
                if (child.hasAttribute('itemprop')) {
                    const itemProp = child.getAttribute('itemprop');
                    const itemType = child.getAttribute('itemtype');

                    const childInfo = {
                        itemType: itemType
                    };

                    traverseItem(child, childInfo);

                    itemProp.split(' ').forEach(propName => {
                        addValue(information, propName, childInfo);
                    });
                }
            } else if (child.hasAttribute('itemprop')) {
                const itemProp = child.getAttribute('itemprop');
                itemProp.split(' ').forEach(propName => {
                    if (propName === 'url') {
                        addValue(information, propName, child.href);
                    } else {
                        addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || ""));
                    }
                });
                traverseItem(child, information);
            } else {
                traverseItem(child, information);
            }
        }
    }

    const microdata = [];

    document.querySelectorAll("[itemscope]").forEach(function(elem) {
        const itemType = elem.getAttribute('itemtype');
        const information = {
            itemType: itemType
        };
        traverseItem(elem, information);
        microdata.push(information);
    });

    return microdata;
};
let getPageMetadata = function() {
    // Gathers structured metadata from the page. The result may contain:
    //   "jsonld"    - parsed JSON-LD (see below),
    //   "microdata" - array from _getMicrodata, when non-empty,
    //   "meta_tags" - object from _getMetaTags, when non-empty.
    let jsonld = _getJsonLd();
    let metaTags = _getMetaTags();
    let microdata = _getMicrodata();
    let results = {};

    if (jsonld.length == 1) {
        // Single block: the parsed value, or the raw one-element array when the
        // block is not valid JSON (same shape as before).
        try {
            results["jsonld"] = JSON.parse(jsonld[0]);
        }
        catch (e) {
            results["jsonld"] = jsonld;
        }
    }
    else if (jsonld.length > 1) {
        // Several blocks: parse each on its own. Passing the whole array to
        // JSON.parse joins the texts with commas, which is almost never valid
        // JSON, so nothing was ever parsed. Blocks that fail to parse are kept
        // as their raw text.
        results["jsonld"] = jsonld.map(function(text) {
            try {
                return JSON.parse(text);
            }
            catch (e) {
                return text;
            }
        });
    }

    if (microdata.length > 0) {
        results["microdata"] = microdata;
    }

    // Object.keys avoids calling metaTags.hasOwnProperty, which a
    // <meta name="hasOwnProperty"> tag would have replaced with a string.
    if (Object.keys(metaTags).length > 0) {
        results["meta_tags"] = metaTags;
    }

    return results;
};
let getVisibleText = function() {
    // Approximates the text currently visible in the viewport: the content of
    // every text node that has at least one non-empty client rectangle
    // overlapping the viewport, with a line break after text whose parent is
    // not displayed inline.

    // Get the window's current viewport boundaries
    const viewportHeight = window.innerHeight || document.documentElement.clientHeight;
    const viewportWidth = window.innerWidth || document.documentElement.clientWidth;

    // A rect counts when it has area and is inside (or partially inside) the viewport.
    const overlapsViewport = function(rect) {
        return rect.width > 0 &&
            rect.height > 0 &&
            rect.bottom >= 0 &&
            rect.right >= 0 &&
            rect.top <= viewportHeight &&
            rect.left <= viewportWidth;
    };

    const noLineBreakDisplays = ["inline", "hidden", "none"];
    const pieces = [];

    const walker = document.createTreeWalker(
        document.body,
        NodeFilter.SHOW_TEXT,
        null,
        false
    );

    for (let textNode = walker.nextNode(); textNode; textNode = walker.nextNode()) {
        // A range gives access to the bounding rectangles of the text node.
        const range = document.createRange();
        range.selectNodeContents(textNode);

        // One visible rect is enough; the remaining rects are not examined.
        if (!Array.from(range.getClientRects()).some(overlapsViewport)) {
            continue;
        }

        pieces.push(textNode.nodeValue.replace(/\s+/g, " "));

        // Is the parent a block element?
        const parent = textNode.parentNode;
        if (parent) {
            const display = window.getComputedStyle(parent).display;
            if (noLineBreakDisplays.indexOf(display) === -1) {
                pieces.push("\n");
            }
        }
    }

    // Remove blank lines and collapse repeated line breaks.
    return pieces.join("").replace(/^\s*\n/gm, "").trim().replace(/\n+/g, "\n");
};
// Public API of the MultimodalWebSurfer object. Only these five functions are
// exposed; every other helper defined above stays private to the closure.
return {
getInteractiveRects: getInteractiveRects,
getVisualViewport: getVisualViewport,
getFocusedElementId: getFocusedElementId,
getPageMetadata: getPageMetadata,
getVisibleText: getVisibleText,
};
})();

View File

@@ -0,0 +1,578 @@
import asyncio
import base64
import io
import os
import random
import warnings
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
from playwright._impl._errors import Error as PlaywrightError
from playwright._impl._errors import TimeoutError
from playwright.async_api import Download, Page
from ._types import (
InteractiveRegion,
VisualViewport,
interactiveregion_from_dict,
visualviewport_from_dict,
)
markitdown: ModuleType | None = None
try:
# Suppress warnings from markitdown -- which is pretty chatty
warnings.filterwarnings(action="ignore", module="markitdown")
import markitdown
except ImportError:
pass
class PlaywrightController:
"""
A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling.
Args:
downloads_folder (str | None): The folder to save downloads to. If None, downloads are not saved.
animate_actions (bool): Whether to animate the actions (create fake cursor to click).
viewport_width (int): The width of the viewport.
viewport_height (int): The height of the viewport.
_download_handler (Optional[Callable[[Download], None]]): A function to handle downloads.
to_resize_viewport (bool): Whether to resize the viewport
"""
def __init__(
    self,
    downloads_folder: str | None = None,
    animate_actions: bool = False,
    viewport_width: int = 1440,
    viewport_height: int = 900,
    _download_handler: Optional[Callable[[Download], None]] = None,
    to_resize_viewport: bool = True,
) -> None:
    """
    Store the configuration and load the page helper script from disk.

    The script is read from ``page_script.js`` located next to this module.
    """
    # Sanity checks on the arguments.
    assert isinstance(animate_actions, bool)
    assert isinstance(viewport_width, int)
    assert isinstance(viewport_height, int)
    assert viewport_height > 0
    assert viewport_width > 0

    # Configuration supplied by the caller.
    self.downloads_folder = downloads_folder
    self.animate_actions = animate_actions
    self.viewport_width = viewport_width
    self.viewport_height = viewport_height
    self.to_resize_viewport = to_resize_viewport
    self._download_handler = _download_handler

    # Internal state.
    self._page_script: str = ""
    self.last_cursor_position: Tuple[float, float] = (0.0, 0.0)
    self._markdown_converter: Optional[Any] | None = None

    # Read page_script
    module_dir = os.path.abspath(os.path.dirname(__file__))
    script_path = os.path.join(module_dir, "page_script.js")
    with open(script_path, "rt", encoding="utf-8") as script_file:
        self._page_script = script_file.read()
async def sleep(self, page: Page, duration: Union[int, float]) -> None:
    """
    Pause the execution for a specified duration.

    Args:
        page (Page): The Playwright page object.
        duration (Union[int, float]): The duration to sleep in seconds. It is converted to the
            milliseconds expected by ``page.wait_for_timeout``.
    """
    assert page is not None
    # wait_for_timeout takes milliseconds, hence the conversion from seconds.
    await page.wait_for_timeout(duration * 1000)
async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion]:
    """
    Collect the interactive regions reported by the page script.

    Args:
        page (Page): The Playwright page object.

    Returns:
        Dict[str, InteractiveRegion]: The typed regions, keyed as returned by the page script.
    """
    assert page is not None

    # Best effort: evaluate the helper script first; failures are deliberately ignored.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    # Read the regions from the DOM
    raw_regions = cast(
        Dict[str, Dict[str, Any]],
        await page.evaluate("MultimodalWebSurfer.getInteractiveRects();"),
    )

    # Convert the results into appropriate types
    assert isinstance(raw_regions, dict)
    regions: Dict[str, InteractiveRegion] = {}
    for element_id, raw_region in raw_regions.items():
        assert isinstance(element_id, str)
        regions[element_id] = interactiveregion_from_dict(raw_region)
    return regions
async def get_visual_viewport(self, page: Page) -> VisualViewport:
    """
    Read the visual viewport description from the page script.

    Args:
        page (Page): The Playwright page object.

    Returns:
        VisualViewport: The typed viewport values reported by the page script.
    """
    assert page is not None

    # Best effort: evaluate the helper script first; failures are deliberately ignored.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    raw_viewport = await page.evaluate("MultimodalWebSurfer.getVisualViewport();")
    return visualviewport_from_dict(raw_viewport)
async def get_focused_rect_id(self, page: Page) -> str | None:
    """
    Look up the identifier of the element that currently has focus.

    Args:
        page (Page): The Playwright page object.

    Returns:
        str | None: The identifier as a string, or None when the page script reports none.
    """
    assert page is not None

    # Best effort: evaluate the helper script first; failures are deliberately ignored.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    focused_id = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();")
    if focused_id is None:
        return None
    return str(focused_id)
async def get_page_metadata(self, page: Page) -> Dict[str, Any]:
    """
    Read the metadata dictionary produced by the page script.

    Args:
        page (Page): The Playwright page object.

    Returns:
        Dict[str, Any]: The metadata dictionary returned by the page script.
    """
    assert page is not None

    # Best effort: evaluate the helper script first; failures are deliberately ignored.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    metadata = await page.evaluate("MultimodalWebSurfer.getPageMetadata();")
    assert isinstance(metadata, dict)
    return cast(Dict[str, Any], metadata)
async def on_new_page(self, page: Page) -> None:
    """
    Prepare a newly opened page for use by the controller.

    Registers the download handler (when one was configured), optionally resizes the viewport,
    installs the page helper script as an init script and waits for the page to load.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    # The handler defaults to None; only register a real callable instead of relying on how
    # the event emitter treats a None listener.
    if self._download_handler is not None:
        page.on("download", self._download_handler)  # type: ignore
    if self.to_resize_viewport and self.viewport_width and self.viewport_height:
        await page.set_viewport_size({"width": self.viewport_width, "height": self.viewport_height})
    await self.sleep(page, 0.2)
    await page.add_init_script(path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"))
    await page.wait_for_load_state()
async def back(self, page: Page) -> None:
    """
    Go one step back in the page's navigation history.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    _ = await page.go_back()
async def visit_page(self, page: Page, url: str) -> Tuple[bool, bool]:
    """
    Visit a specified URL, saving it as a download when the navigation turns into one.

    Args:
        page (Page): The Playwright page object.
        url (str): The URL to visit.

    Returns:
        Tuple[bool, bool]: ``(reset_prior_metadata_hash, reset_last_download)``. A regular page load
        yields ``(True, False)``; a saved download yields ``(False, True)``.

    Raises:
        Exception: Any navigation error other than an aborted navigation handled as a download
            is re-raised unchanged.
    """
    import html  # local import: only needed to escape the download confirmation page

    assert page is not None
    try:
        # Regular webpage
        await page.goto(url)
        await page.wait_for_load_state()
    except Exception as e_outer:
        # Only an aborted navigation with a configured downloads folder is treated as a download.
        if not (self.downloads_folder and "net::ERR_ABORTED" in str(e_outer)):
            raise

        # Downloaded file
        async with page.expect_download() as download_info:
            try:
                await page.goto(url)
            except Exception as e_inner:
                # The repeated navigation is expected to abort again; anything else is an error.
                if "net::ERR_ABORTED" not in str(e_inner):
                    raise
            download = await download_info.value
            fname = os.path.join(self.downloads_folder, download.suggested_filename)
            await download.save_as(fname)
            # The file name is supplied by the remote server, so it is escaped before being
            # embedded in the confirmation page's HTML.
            safe_name = html.escape(download.suggested_filename)
            safe_path = html.escape(fname)
            message = f"<body style=\"margin: 20px;\"><h1>Successfully downloaded '{safe_name}' to local path:<br><br>{safe_path}</h1></body>"
            await page.goto(
                "data:text/html;base64," + base64.b64encode(message.encode("utf-8")).decode("utf-8")
            )
        return False, True
    return True, False
async def page_down(self, page: Page) -> None:
    """
    Scroll down by the viewport height less 50 pixels.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    scroll_distance = self.viewport_height - 50
    await page.evaluate(f"window.scrollBy(0, {scroll_distance});")
async def page_up(self, page: Page) -> None:
    """
    Scroll up by the viewport height less 50 pixels.

    Args:
        page (Page): The Playwright page object.
    """
    assert page is not None
    scroll_distance = self.viewport_height - 50
    await page.evaluate(f"window.scrollBy(0, -{scroll_distance});")
async def gradual_cursor_animation(
    self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float
) -> None:
    """
    Animate the cursor movement gradually from start to end coordinates.

    The element with id ``red-cursor`` is moved along a straight line in equal steps, including
    both end points. Frames are skipped silently when that element does not exist.

    Args:
        page (Page): The Playwright page object.
        start_x (float): The starting x-coordinate.
        start_y (float): The starting y-coordinate.
        end_x (float): The ending x-coordinate.
        end_y (float): The ending y-coordinate.
    """
    # animation helper
    steps = 20
    # range(steps + 1) includes the final fraction 1.0 so the cursor actually reaches the end
    # point instead of stopping at 95% of the path.
    for step in range(steps + 1):
        fraction = step / steps
        x = start_x + (end_x - start_x) * fraction
        y = start_y + (end_y - start_y) * fraction
        # The coordinates are passed as an argument rather than formatted into the script, and
        # a missing cursor element is tolerated.
        await page.evaluate(
            """([x, y]) => {
                let cursor = document.getElementById('red-cursor');
                if (cursor) {
                    cursor.style.left = x + 'px';
                    cursor.style.top = y + 'px';
                }
            }""",
            [x, y],
        )
        await asyncio.sleep(0.05)

    self.last_cursor_position = (end_x, end_y)
async def add_cursor_box(self, page: Page, identifier: str) -> None:
    """
    Highlight the element with a red border and add a red cursor dot to the page.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier (value of its ``__elementId`` attribute).
    """
    # animation helper
    # The identifier is passed as an argument and CSS-escaped inside the page, so quotes or
    # other special characters cannot break out of the selector or the script.
    await page.evaluate(
        """(identifier) => {
            let elm = document.querySelector(`[__elementId="${CSS.escape(identifier)}"]`);
            if (elm) {
                elm.style.transition = 'border 0.3s ease-in-out';
                elm.style.border = '2px solid red';
            }
        }""",
        identifier,
    )

    # Let the border transition (0.3s) play before the cursor appears.
    await asyncio.sleep(0.3)

    # Create a red cursor
    await page.evaluate("""
        (function() {
            let cursor = document.createElement('div');
            cursor.id = 'red-cursor';
            cursor.style.width = '10px';
            cursor.style.height = '10px';
            cursor.style.backgroundColor = 'red';
            cursor.style.position = 'absolute';
            cursor.style.borderRadius = '50%';
            cursor.style.zIndex = '10000';
            document.body.appendChild(cursor);
        })();
    """)
async def remove_cursor_box(self, page: Page, identifier: str) -> None:
    """
    Clear the red border from the element and remove the red cursor dot, if present.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier (value of its ``__elementId`` attribute).
    """
    # Remove the highlight and cursor.
    # The identifier is passed as an argument and CSS-escaped inside the page, so quotes or
    # other special characters cannot break out of the selector or the script.
    await page.evaluate(
        """(identifier) => {
            let elm = document.querySelector(`[__elementId="${CSS.escape(identifier)}"]`);
            if (elm) {
                elm.style.border = '';
            }
            let cursor = document.getElementById('red-cursor');
            if (cursor) {
                cursor.remove();
            }
        }""",
        identifier,
    )
async def click_id(self, page: Page, identifier: str) -> Page | None:
    """
    Click the centre of the element with the given identifier.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.

    Returns:
        Page | None: The page opened by the click when a popup appears within one second,
        otherwise None.

    Raises:
        ValueError: If the element does not show up within five seconds.
    """
    assert page is not None
    new_page: Page | None = None
    animate = self.animate_actions

    # See if it exists
    target = page.locator(f"[__elementId='{identifier}']")
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Bring it into view and work out where to click
    await target.scroll_into_view_if_needed()
    await asyncio.sleep(0.3)
    box = cast(Dict[str, Union[int, float]], await target.bounding_box())
    click_x = box["x"] + box["width"] / 2
    click_y = box["y"] + box["height"] / 2

    if animate:
        await self.add_cursor_box(page, identifier)
        # Move cursor to the box slowly
        from_x, from_y = self.last_cursor_position
        await self.gradual_cursor_animation(page, from_x, from_y, click_x, click_y)
        await asyncio.sleep(0.1)

    try:
        # Give it a chance to open a new page
        async with page.expect_event("popup", timeout=1000) as page_info:  # type: ignore
            await page.mouse.click(click_x, click_y, delay=10)
        new_page = await page_info.value  # type: ignore
        assert isinstance(new_page, Page)
        await self.on_new_page(new_page)
    except TimeoutError:
        # No popup within the timeout: the click stayed on the current page.
        pass

    if animate:
        await self.remove_cursor_box(page, identifier)

    return new_page  # type: ignore
async def hover_id(self, page: Page, identifier: str) -> None:
    """
    Move the mouse to the centre of the element with the given identifier.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.

    Raises:
        ValueError: If the element does not show up within five seconds.
    """
    assert page is not None
    animate = self.animate_actions

    # See if it exists
    target = page.locator(f"[__elementId='{identifier}']")
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Bring it into view and work out where to hover
    await target.scroll_into_view_if_needed()
    await asyncio.sleep(0.3)
    box = cast(Dict[str, Union[int, float]], await target.bounding_box())
    hover_x = box["x"] + box["width"] / 2
    hover_y = box["y"] + box["height"] / 2

    if animate:
        await self.add_cursor_box(page, identifier)
        # Move cursor to the box slowly
        from_x, from_y = self.last_cursor_position
        await self.gradual_cursor_animation(page, from_x, from_y, hover_x, hover_y)
        await asyncio.sleep(0.1)

    await page.mouse.move(hover_x, hover_y)

    if animate:
        await self.remove_cursor_box(page, identifier)
async def fill_id(self, page: Page, identifier: str, value: str, press_enter: bool = True) -> None:
    """
    Type a value into the element with the given identifier.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier.
        value (str): The value to fill.
        press_enter (bool): Whether to press Enter on the element after the value is entered.

    Raises:
        ValueError: If the element does not show up within five seconds.
    """
    assert page is not None
    animate = self.animate_actions

    # See if it exists
    target = page.locator(f"[__elementId='{identifier}']")
    try:
        await target.wait_for(timeout=5000)
    except TimeoutError:
        raise ValueError("No such element.") from None

    # Bring it into view
    await target.scroll_into_view_if_needed()
    box = cast(Dict[str, Union[int, float]], await target.bounding_box())

    if animate:
        await self.add_cursor_box(page, identifier)
        # Move cursor to the box slowly
        from_x, from_y = self.last_cursor_position
        to_x = box["x"] + box["width"] / 2
        to_y = box["y"] + box["height"] / 2
        await self.gradual_cursor_animation(page, from_x, from_y, to_x, to_y)
        await asyncio.sleep(0.1)

    # Focus on the element
    await target.focus()

    if not animate:
        # Fast path: fill in one go, typing key by key only if fill is rejected.
        try:
            await target.fill(value)
        except PlaywrightError:
            await target.press_sequentially(value)
    else:
        # fill char by char to mimic human speed for short text and type fast for long text
        per_key_delay = 50 + 100 * random.random() if len(value) < 100 else 10
        await target.press_sequentially(value, delay=per_key_delay)

    if press_enter:
        await target.press("Enter")

    if animate:
        await self.remove_cursor_box(page, identifier)
async def scroll_id(self, page: Page, identifier: str, direction: str) -> None:
    """
    Scroll the element with the given identifier by one client height.

    Args:
        page (Page): The Playwright page object.
        identifier (str): The element identifier (value of its ``__elementId`` attribute).
        direction (str): ``"up"`` scrolls up; any other value scrolls down.
    """
    assert page is not None
    # Both values are passed as arguments (and the identifier is CSS-escaped inside the page),
    # so quotes or other special characters cannot break out of the selector or the script.
    await page.evaluate(
        """([identifier, direction]) => {
            let elm = document.querySelector(`[__elementId="${CSS.escape(identifier)}"]`);
            if (elm) {
                if (direction == "up") {
                    elm.scrollTop = Math.max(0, elm.scrollTop - elm.clientHeight);
                }
                else {
                    elm.scrollTop = Math.min(elm.scrollHeight - elm.clientHeight, elm.scrollTop + elm.clientHeight);
                }
            }
        }""",
        [identifier, direction],
    )
async def get_webpage_text(self, page: Page, n_lines: int = 50) -> str:
    """
    Return the beginning of the page's inner text with blank lines removed.

    Args:
        page (Page): The Playwright page object.
        n_lines (int): How many leading lines of the inner text to consider; blank lines among
            them are dropped, so fewer lines may be returned.

    Returns:
        str: The retained lines joined by newlines, or an empty string if anything fails.
    """
    assert page is not None
    try:
        inner_text = await page.evaluate("""() => {
            return document.body.innerText;
        }""")
        leading_lines = inner_text.split("\n")[:n_lines]
        # remove empty lines
        page_text = "\n".join(line for line in leading_lines if line.strip())
        assert isinstance(page_text, str)
        return page_text
    except Exception:
        # Any failure (evaluation error, unexpected result type) yields an empty text.
        return ""
async def get_visible_text(self, page: Page) -> str:
    """
    Return the text the page script considers visible in the browser viewport (approximately).

    Args:
        page (Page): The Playwright page object.

    Returns:
        str: The visible text reported by the page script.
    """
    assert page is not None

    # Best effort: evaluate the helper script first; failures are deliberately ignored.
    try:
        await page.evaluate(self._page_script)
    except Exception:
        pass

    visible_text = await page.evaluate("MultimodalWebSurfer.getVisibleText();")
    assert isinstance(visible_text, str)
    return visible_text
async def get_page_markdown(self, page: Page) -> str:
    """
    Retrieve the markdown content of the web page.

    When the optional ``markitdown`` package is available, the page's HTML is converted to
    markdown with a converter that is created on first use and then reused. Otherwise the
    method falls back to the first 200 lines of the page's plain text.

    Args:
        page (Page): The Playwright page object.

    Returns:
        str: The markdown content of the page, or its plain text when markitdown is unavailable.
    """
    assert page is not None

    # Without markitdown there is nothing to convert with: fall back to plain text.
    if markitdown is None:
        return await self.get_webpage_text(page, n_lines=200)

    # Create the converter lazily. This check is kept separate from the availability check
    # above; combining them made every call after the first one fall back to plain text.
    if self._markdown_converter is None:
        self._markdown_converter = markitdown.MarkItDown()
    assert self._markdown_converter is not None

    html = await page.evaluate("document.documentElement.outerHTML;")
    res = self._markdown_converter.convert_stream(
        io.BytesIO(html.encode("utf-8")), file_extension=".html", url=page.url
    )
    assert hasattr(res, "text_content") and isinstance(res.text_content, str)
    return res.text_content