first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
from ._openai_agent import OpenAIAgent
from ._openai_assistant_agent import OpenAIAssistantAgent
# Names re-exported as the public API of this package.
__all__ = [
    "OpenAIAgent",
    "OpenAIAssistantAgent",
]

View File

@@ -0,0 +1,682 @@
import asyncio
import logging
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallSummaryMessage,
)
from agentdhal_core import CancellationToken, Component
from agentdhal_core.models import UserMessage
from pydantic import BaseModel, Field
from typing_extensions import NotRequired, TypedDict
from openai import AsyncAzureOpenAI, AsyncOpenAI # type: ignore
# Number of characters to display when previewing image content in logs and UI.
# Base64-encoded images can be very long, so previews are truncated for readability.
IMAGE_CONTENT_PREVIEW_LENGTH = 50
# NOTE: We use the new Responses API, so ChatCompletion imports are not needed.
# Logger named by EVENT_LOGGER_NAME; used below to report Responses API errors.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
# TypedDict classes for built-in tool configurations
class FileSearchToolConfig(TypedDict):
    """Configuration dict for the built-in ``file_search`` tool.

    ``type`` and ``vector_store_ids`` are required keys; the remaining keys may be omitted.
    """

    type: Literal["file_search"]
    vector_store_ids: List[str]  # required - The IDs of the vector stores to search
    max_num_results: NotRequired[int]  # optional
    ranking_options: NotRequired[Dict[str, Any]]  # optional
    filters: NotRequired[Dict[str, Any]]  # optional
class WebSearchToolConfig(TypedDict):
    """Configuration dict for the built-in ``web_search_preview`` tool.

    Only ``type`` is required; both other keys may be omitted.
    """

    type: Literal["web_search_preview"]
    search_context_size: NotRequired[str]  # optional
    user_location: NotRequired[Union[str, Dict[str, Any]]]  # optional - Can be string or structured location
class ComputerUseToolConfig(TypedDict):
    """Configuration dict for the built-in ``computer_use_preview`` tool.

    Every key is required.
    """

    type: Literal["computer_use_preview"]
    display_height: int  # required - Display height in pixels
    display_width: int  # required - Display width in pixels
    environment: str  # required - Environment type for computer use
class MCPToolConfig(TypedDict):
    """Configuration dict for the built-in ``mcp`` tool.

    ``type``, ``server_label`` and ``server_url`` are required; the remaining keys may be omitted.
    """

    type: Literal["mcp"]
    server_label: str  # required - Label for the MCP server
    server_url: str  # required - URL of the MCP server
    allowed_tools: NotRequired[List[str]]  # optional - List of allowed tools
    headers: NotRequired[Dict[str, str]]  # optional - HTTP headers for requests
    # NOTE(review): the hosted API may expect a string/object policy here rather than a bool - confirm.
    require_approval: NotRequired[bool]  # optional - Whether to require user approval
class CodeInterpreterToolConfig(TypedDict):
    """Configuration dict for the built-in ``code_interpreter`` tool.

    Both keys are required; ``container`` may be a string or a dict.
    """

    type: Literal["code_interpreter"]
    container: str | Dict[str, Any]  # required - Container configuration for code execution
class ImageGenerationToolConfig(TypedDict):
    """Configuration dict for the built-in ``image_generation`` tool.

    Only ``type`` is required.
    """

    type: Literal["image_generation"]
    background: NotRequired[str]  # optional - Background color or image
    input_image_mask: NotRequired[str]  # optional - Mask for input image editing
class LocalShellToolConfig(TypedDict):
    """Configuration dict for the built-in ``local_shell`` tool.

    WARNING: This tool is only supported with the 'codex-mini-latest' model
    and is available exclusively through the Responses API.
    """

    type: Literal["local_shell"]
    # Note: local_shell currently has no additional parameters in the API
# Union of every built-in tool configuration dict declared above; used to
# annotate the dict entries accepted by OpenAIAgent's ``tools`` argument.
BuiltinToolConfig = Union[
    FileSearchToolConfig,
    WebSearchToolConfig,
    ComputerUseToolConfig,
    MCPToolConfig,
    CodeInterpreterToolConfig,
    ImageGenerationToolConfig,
    LocalShellToolConfig,
]
# Define ImageMessage class early since it's used in _convert_message_to_openai_message
class ImageMessage(BaseChatMessage):
    """Chat message whose content is a single image reference."""

    content: str  # URL or base64 string

    def to_model_message(self) -> UserMessage:
        """Wrap the image reference in a ``UserMessage`` attributed to this message's source."""
        return UserMessage(source=self.source, content=self.content)

    def to_model_text(self) -> str:
        """Return the fixed placeholder that stands in for the image in model text."""
        return "[image]"

    def to_text(self) -> str:
        """Return a short human-readable preview of the image reference.

        Content longer than ``IMAGE_CONTENT_PREVIEW_LENGTH`` characters (typically
        base64 data) is cut to that length and suffixed with ``...``.
        """
        # Short content is shown in full; only long content is truncated.
        if len(self.content) <= IMAGE_CONTENT_PREVIEW_LENGTH:
            return f"[Image: {self.content}]"
        return f"[Image: {self.content[:IMAGE_CONTENT_PREVIEW_LENGTH]}...]"
class OpenAIMessageContent(TypedDict):
    """A text content part of an OpenAI-format message."""

    type: str
    text: str


class OpenAIImageUrlContent(TypedDict):
    """The ``image_url`` payload of an image content part."""

    url: str


class OpenAIImageContent(TypedDict):
    """An image content part of an OpenAI-format message."""

    type: str
    image_url: OpenAIImageUrlContent


class OpenAIMessage(TypedDict):
    """An OpenAI-format message: a role plus either plain text or a list of content parts."""

    role: str
    content: Union[str, List[Union[OpenAIMessageContent, OpenAIImageContent]]]
def _convert_message_to_openai_message(
    message: Union[TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage],
) -> OpenAIMessage:
    """Convert an AutoGen message to an OpenAI message format.

    Args:
        message: The chat message to convert.

    Returns:
        A dict with ``role`` and ``content`` keys. ``TextMessage`` keeps its source
        as the role when it is ``"user"``, ``"system"`` or ``"assistant"``; every
        other message is sent with the ``"user"`` role. ``MultiModalMessage``
        produces a list of content parts, all other messages a plain string.
    """
    if isinstance(message, TextMessage):
        # Any source other than the three known roles is treated as user input.
        role = message.source if message.source in ("user", "system", "assistant") else "user"
        return {"role": role, "content": str(message.content)}

    if isinstance(message, MultiModalMessage):
        content_parts: List[Union[OpenAIMessageContent, OpenAIImageContent]] = []
        for part in message.content:
            if isinstance(part, str):
                # Plain string parts were previously dropped silently; forward them as text.
                content_parts.append({"type": "text", "text": part})
            elif isinstance(part, TextMessage):
                content_parts.append({"type": "text", "text": str(part.content)})
            elif isinstance(part, ImageMessage):
                image_content = str(part.content)
                content_parts.append({"type": "image_url", "image_url": {"url": image_content}})
            # NOTE(review): parts of any other type (for example image objects from
            # the core package) are still skipped - confirm whether they need converting.
        return {"role": "user", "content": content_parts}

    return {"role": "user", "content": str(message.content)}
class OpenAIAgentState(BaseModel):
    """Serializable snapshot of an ``OpenAIAgent``'s conversation state."""

    # Discriminator identifying the kind of state payload.
    type: str = Field(default="OpenAIAgentState")
    # ID of the most recent API response, if any.
    response_id: Optional[str] = None
    # Locally recorded conversation as role/content dicts.
    history: List[Dict[str, Any]] = Field(default_factory=list)
class OpenAIAgentConfig(BaseModel):
    """
    Configuration model for OpenAI agent supporting OpenAI built-in tools only.

    The fields mirror the constructor arguments of ``OpenAIAgent`` except for the
    client, which is not part of the configuration.

    .. versionchanged:: v0.7.0
       Added support for built-in tools in JSON configuration via _to_config and _from_config methods.
       The tools field accepts built-in tool configurations (dict format) and built-in tool names (string format).
       Custom tools are not supported.
    """

    name: str
    description: str
    model: str
    instructions: str
    # Each entry is either a built-in tool name or a tool configuration dict.
    tools: List[Dict[str, Any] | str] | None = None
    temperature: Optional[float] = 1
    max_output_tokens: Optional[int] = None
    json_mode: bool = False
    store: bool = True
    truncation: str = "disabled"
class OpenAIAgent(BaseChatAgent, Component[OpenAIAgentConfig]):
"""
An agent implementation that uses the OpenAI Responses API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]"
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Responses API to generate responses with capabilities like:
* Multi-turn conversations
* Built-in tool support (file_search, code_interpreter, web_search_preview, etc.)
Currently, custom tools are not supported.
.. versionchanged:: v0.7.0
Added support for built-in tool types like file_search, web_search_preview,
code_interpreter, computer_use_preview, image_generation, and mcp.
Added support for tool configurations with required and optional parameters.
Built-in tools are split into two categories:
**Tools that can use string format** (no required parameters):
- web_search_preview: Can be used as "web_search_preview" or with optional config
(user_location, search_context_size)
- image_generation: Can be used as "image_generation" or with optional config (background, input_image_mask)
- local_shell: Can be used as "local_shell" (WARNING: Only works with codex-mini-latest model)
**Tools that REQUIRE dict configuration** (have required parameters):
- file_search: MUST use dict with vector_store_ids (List[str])
- computer_use_preview: MUST use dict with display_height (int), display_width (int), environment (str)
- code_interpreter: MUST use dict with container (str or dict)
- mcp: MUST use dict with server_label (str), server_url (str)
Using required-parameter tools in string format will raise a ValueError with helpful error messages.
The tools parameter type annotation only accepts string values for tools that don't require parameters.
Note:
Custom tools (autogen FunctionTool or other user-defined tools) are not supported by this agent.
Only OpenAI built-in tools provided via the Responses API are supported.
Args:
name (str): Name of the agent
description (str): Description of the agent's purpose
client (Union[AsyncOpenAI, AsyncAzureOpenAI]): OpenAI client instance
model (str): Model to use (e.g. "gpt-4.1")
instructions (str): System instructions for the agent
tools (Optional[Iterable[Union[str, BuiltinToolConfig]]]): Tools the agent can use.
Supported string values (no required parameters): "web_search_preview", "image_generation", "local_shell".
Dict values can provide configuration for built-in tools with parameters.
Required parameters for built-in tools:
- file_search: vector_store_ids (List[str])
- computer_use_preview: display_height (int), display_width (int), environment (str)
- code_interpreter: container (str or dict)
- mcp: server_label (str), server_url (str)
Optional parameters for built-in tools:
- file_search: max_num_results (int), ranking_options (dict), filters (dict)
- web_search_preview: user_location (str or dict), search_context_size (str)
- image_generation: background (str), input_image_mask (str)
- mcp: allowed_tools (List[str]), headers (dict), require_approval (bool)
Special tools with model restrictions:
- local_shell: Only works with "codex-mini-latest" model (WARNING: Very limited support)
Custom tools are not supported.
temperature (Optional[float]): Temperature for response generation (default: 1)
max_output_tokens (Optional[int]): Maximum output tokens
json_mode (bool): Whether to use JSON mode (default: False)
store (bool): Whether to store conversations (default: True)
truncation (str): Truncation strategy (default: "disabled")
Example:
Basic usage with built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example():
client = AsyncOpenAI()
agent = OpenAIAgent(
name="SimpleAgent",
description="A simple OpenAI agent using the Responses API",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant.",
tools=["web_search_preview"], # Only tools without required params
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example())
Usage with configured built-in tools:
.. code-block:: python
import asyncio
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.agents.openai import OpenAIAgent
from openai import AsyncOpenAI
async def example_with_configs():
client = AsyncOpenAI()
# Configure tools with required and optional parameters
tools = [
# {
# "type": "file_search",
# "vector_store_ids": ["vs_abc123"], # required
# "max_num_results": 10, # optional
# },
# {
# "type": "computer_use_preview",
# "display_height": 1024, # required
# "display_width": 1280, # required
# "environment": "linux", # required
# },
{
"type": "code_interpreter",
"container": {"type": "auto"}, # required
},
# {
# "type": "mcp",
# "server_label": "my-mcp-server", # required
# "server_url": "http://localhost:3000", # required
# },
{
"type": "web_search_preview",
"user_location": { # optional - structured location
"type": "approximate", # required: "approximate" or "exact"
"country": "US", # optional
"region": "CA", # optional
"city": "San Francisco", # optional
},
"search_context_size": "low", # optional
},
# "image_generation", # Simple tools can still use string format
]
agent = OpenAIAgent(
name="ConfiguredAgent",
description="An agent with configured tools",
client=client,
model="gpt-4.1",
instructions="You are a helpful assistant with specialized tools.",
tools=tools, # type: ignore
)
await Console(agent.run_stream(task="Search for recent AI developments"))
asyncio.run(example_with_configs())
Note:
Custom tools are not supported by OpenAIAgent. Use only built-in tools from the Responses API.
"""
component_config_schema = OpenAIAgentConfig
component_provider_override = "agentdhal_extensions.agents.openai.OpenAIAgent"
def __init__(
    self: "OpenAIAgent",
    name: str,
    description: str,
    client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["web_search_preview", "image_generation", "local_shell"],
                BuiltinToolConfig,
            ]
        ]
    ] = None,
    temperature: Optional[float] = 1,
    max_output_tokens: Optional[int] = None,
    json_mode: bool = False,
    store: bool = True,
    truncation: str = "disabled",
) -> None:
    """Store the agent settings and register the requested built-in tools.

    String entries in ``tools`` are validated and expanded by
    ``_add_builtin_tool``; dict entries that carry a ``"type"`` key are stored
    as given.

    Raises:
        ValueError: If a tool entry is neither a string nor a dict with a
            ``"type"`` key, or if ``_add_builtin_tool`` rejects a tool name.
    """
    super().__init__(name, description)

    # Request settings.
    self._client: Union[AsyncOpenAI, AsyncAzureOpenAI] = client
    self._model: str = model
    self._instructions: str = instructions
    self._temperature: Optional[float] = temperature
    self._max_output_tokens: Optional[int] = max_output_tokens
    self._json_mode: bool = json_mode
    self._store: bool = store
    self._truncation: str = truncation

    # Conversation state.
    self._last_response_id: Optional[str] = None
    self._message_history: List[Dict[str, Any]] = []

    # Tool registration; relies on self._model being set already (local_shell check).
    self._tools: List[Dict[str, Any]] = []
    for tool in tools if tools is not None else ():
        if isinstance(tool, str):
            # Handle built-in tool types
            self._add_builtin_tool(tool)
            continue
        if not (isinstance(tool, dict) and "type" in tool):
            raise ValueError(f"Unsupported tool type: {type(tool)}")
        # Handle configured built-in tools
        self._tools.append(cast(dict[str, Any], tool))
def _add_builtin_tool(self, tool_name: str) -> None:
"""Add a built-in tool by name."""
# Skip if an identical tool has already been registered (idempotent behaviour)
if any(td.get("type") == tool_name for td in self._tools):
return # Duplicate ignore rather than raise to stay backward-compatible
# Only allow string format for tools that don't require parameters
if tool_name == "web_search_preview":
self._tools.append({"type": "web_search_preview"})
elif tool_name == "image_generation":
self._tools.append({"type": "image_generation"})
elif tool_name == "local_shell":
# Special handling for local_shell - very limited model support
if self._model != "codex-mini-latest":
raise ValueError(
f"Tool 'local_shell' is only supported with model 'codex-mini-latest', "
f"but current model is '{self._model}'. "
f"This tool is available exclusively through the Responses API and has severe limitations. "
f"Consider using agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool with "
f"agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor for shell execution instead."
)
self._tools.append({"type": "local_shell"})
elif tool_name in ["file_search", "code_interpreter", "computer_use_preview", "mcp"]:
# These tools require specific parameters and must use dict configuration
raise ValueError(
f"Tool '{tool_name}' requires specific parameters and cannot be added using string format. "
f"Use dict configuration instead. Required parameters for {tool_name}: "
f"{self._get_required_params_help(tool_name)}"
)
else:
raise ValueError(f"Unsupported built-in tool type: {tool_name}")
def _get_required_params_help(self, tool_name: str) -> str:
"""Get help text for required parameters of a tool."""
help_text = {
"file_search": "vector_store_ids (List[str])",
"code_interpreter": "container (str | dict)",
"computer_use_preview": "display_height (int), display_width (int), environment (str)",
"mcp": "server_label (str), server_url (str)",
}
return help_text.get(tool_name, "unknown parameters")
def _convert_message_to_dict(self, message: OpenAIMessage) -> Dict[str, Any]:
    """Return a shallow copy of ``message`` as a plain ``Dict[str, Any]``."""
    return {key: value for key, value in message.items()}
@property
def produced_message_types(
    self: "OpenAIAgent",
) -> Sequence[
    Union[
        Type[TextMessage],
        Type[MultiModalMessage],
        Type[StopMessage],
        Type[ToolCallSummaryMessage],
        Type[HandoffMessage],
    ]
]:
    """List the message classes this agent declares as its possible outputs."""
    declared_types = (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)
    return list(declared_types)
# Custom tool execution is not supported by this agent.
def _build_api_parameters(self: "OpenAIAgent", messages: List[Dict[str, Any]]) -> Dict[str, Any]:
has_system_message = any(msg.get("role") == "system" for msg in messages)
if self._instructions and not has_system_message:
messages = [{"role": "system", "content": self._instructions}] + messages
api_params: Dict[str, Any] = {
"model": self._model,
"input": messages, # Responses API expects 'input'
}
if self._temperature is not None:
api_params["temperature"] = self._temperature
if self._max_output_tokens is not None:
api_params["max_output_tokens"] = self._max_output_tokens
if self._tools:
api_params["tools"] = self._tools
if self._json_mode:
api_params["text"] = {"type": "json_object"}
api_params["store"] = self._store
api_params["truncation"] = self._truncation
if self._last_response_id:
api_params["previous_response_id"] = self._last_response_id
return api_params
async def on_messages(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> Response:
    """Run ``on_messages_stream`` to completion and return its final ``Response``.

    Items yielded before the ``Response`` are collected and merged into
    ``Response.inner_messages`` unless they are already present there.

    Raises:
        ValueError: If the stream finishes without yielding a ``Response``.
    """
    final_response = None
    streamed_items: List[
        Union[AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage]
    ] = []
    async for item in self.on_messages_stream(messages, cancellation_token):
        if not isinstance(item, Response):
            # ModelClientStreamingChunkEvent does not exist in this version, so no filtering is needed here.
            streamed_items.append(item)
        else:
            final_response = item

    if final_response is None:
        raise ValueError("No response was generated")

    if final_response.inner_messages is None:
        final_response.inner_messages = []
    for item in streamed_items:
        if item in final_response.inner_messages:
            continue
        final_response.inner_messages = [*final_response.inner_messages, item]
    return final_response
async def on_messages_stream(
    self: "OpenAIAgent", messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[
    Union[
        AgentEvent, TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage, Response
    ],
    None,
]:
    """Send the conversation to the Responses API and yield one ``Response``.

    Incoming messages are converted to role/content dicts and appended to the
    stored history before the request is made. On success the reply text is
    appended to the history and the response ID is remembered. If anything
    inside the request block raises an ``Exception``, the error is logged and
    a ``Response`` carrying the error text is yielded instead.

    Args:
        messages: New chat messages to add to the conversation.
        cancellation_token: Token the API request future is linked to.
    """
    convertible_types = (TextMessage, MultiModalMessage, StopMessage, ToolCallSummaryMessage, HandoffMessage)

    # NOTE(review): the whole local history is sent as input while
    # _build_api_parameters also sets previous_response_id - confirm this does
    # not duplicate context on the server side.
    input_messages: List[Dict[str, Any]] = list(self._message_history)
    for message in messages:
        if isinstance(message, convertible_types):
            dict_message = self._convert_message_to_dict(_convert_message_to_openai_message(message))
        elif hasattr(message, "content"):
            dict_message = {"role": "user", "content": str(cast(Any, message).content)}
        else:
            dict_message = {"role": "user", "content": str(message)}
        # The same dict object goes into both the request and the stored history.
        input_messages.append(dict_message)
        self._message_history.append(dict_message)

    inner_messages: List[AgentEvent | ChatMessage] = []
    api_params = self._build_api_parameters(input_messages)
    try:
        client = cast(Any, self._client)
        pending_request = asyncio.ensure_future(client.responses.create(**api_params))
        response_obj = await cancellation_token.link_future(pending_request)
        content = getattr(response_obj, "output_text", None)
        self._last_response_id = getattr(response_obj, "id", None)
        # Use a readable placeholder when the API returns no content to aid debugging
        if content is None:
            content_str: str = "[no content returned]"
        else:
            content_str = str(content)
        self._message_history.append({"role": "assistant", "content": content_str})
        yield Response(
            chat_message=TextMessage(source=self.name, content=content_str),
            inner_messages=inner_messages,
        )
    except Exception as e:
        error_message = f"Error generating response: {str(e)}"
        event_logger.error(f"API error: {error_message}", exc_info=True)
        yield Response(
            chat_message=TextMessage(source=self.name, content=error_message),
            inner_messages=inner_messages,
        )
async def on_reset(self: "OpenAIAgent", cancellation_token: CancellationToken) -> None:
    """Forget the conversation: clear the stored history and the last response ID.

    ``cancellation_token`` is accepted but not used.
    """
    self._message_history = []
    self._last_response_id = None
async def save_state(self: "OpenAIAgent") -> Mapping[str, Any]:
    """Return the last response ID and message history as a dumped ``OpenAIAgentState``."""
    snapshot = OpenAIAgentState(
        history=self._message_history,
        response_id=self._last_response_id,
    )
    return snapshot.model_dump()
async def load_state(self: "OpenAIAgent", state: Mapping[str, Any]) -> None:
    """Validate ``state`` as an ``OpenAIAgentState`` and restore the agent from it."""
    restored = OpenAIAgentState.model_validate(state)
    self._message_history = restored.history
    self._last_response_id = restored.response_id
def _to_config(self: "OpenAIAgent") -> OpenAIAgentConfig:
    """Build the declarative configuration for this agent.

    The registered tools are copied into a new list; the client is not part
    of the configuration.

    Returns:
        OpenAIAgentConfig: The configuration describing this agent's settings.
    """
    settings: Dict[str, Any] = {
        "name": self.name,
        "description": self.description,
        "model": self._model,
        "instructions": self._instructions,
        "tools": list(self._tools),
        "temperature": self._temperature,
        "max_output_tokens": self._max_output_tokens,
        "json_mode": self._json_mode,
        "store": self._store,
        "truncation": self._truncation,
    }
    return OpenAIAgentConfig(**settings)
@classmethod
def _from_config(cls: Type["OpenAIAgent"], config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Create an agent from a declarative configuration.

    A new ``AsyncOpenAI`` client is created with default arguments, since the
    configuration carries no client. Tools may be given as names or dicts.

    Args:
        config: The configuration to load the agent from.

    Returns:
        OpenAIAgent: The reconstructed agent.
    """
    from openai import AsyncOpenAI

    settings: Dict[str, Any] = {
        "name": config.name,
        "description": config.description,
        "client": AsyncOpenAI(),
        "model": config.model,
        "instructions": config.instructions,
        "tools": config.tools,
        "temperature": config.temperature,
        "max_output_tokens": config.max_output_tokens,
        "json_mode": config.json_mode,
        "store": config.store,
        "truncation": config.truncation,
    }
    return cls(**settings)
# Add public API wrappers for configuration and tools
def to_config(self) -> OpenAIAgentConfig:
    """Return this agent's configuration; public alias for ``_to_config``."""
    config = self._to_config()
    return config
@classmethod
def from_config(cls, config: OpenAIAgentConfig) -> "OpenAIAgent":
    """Create an agent from ``config``; public alias for ``_from_config``."""
    agent = cls._from_config(config)
    return agent
@property
def tools(self) -> list[Any]:
"""Public access to the agent's tools."""
return self._tools
@property
def model(self) -> str:
"""Public access to the agent's model."""
return self._model

View File

@@ -0,0 +1,715 @@
import asyncio
import json
import logging
import os
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Union,
cast,
)
import aiofiles
from agentdhal_agentchat import EVENT_LOGGER_NAME
from agentdhal_agentchat.agents import BaseChatAgent
from agentdhal_agentchat.base import Response
from agentdhal_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from agentdhal_core import CancellationToken, FunctionCall, Image
from agentdhal_core.models import ChatCompletionClient, FunctionExecutionResult
from agentdhal_core.tools import FunctionTool, Tool
from pydantic import BaseModel, Field
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
from openai.pagination import AsyncCursorPage
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
from openai.types import FileObject
from openai.types.beta import thread_update_params
from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
from openai.types.beta.file_search_tool_param import FileSearchToolParam
from openai.types.beta.function_tool_param import FunctionToolParam
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
from openai.types.beta.threads import Message, MessageDeleted, Run
from openai.types.beta.threads.image_url_content_block_param import ImageURLContentBlockParam
from openai.types.beta.threads.image_url_param import ImageURLParam
from openai.types.beta.threads.message_content_part_param import (
MessageContentPartParam,
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.vector_store import VectorStore
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
    """Build the Assistant API function-tool parameter for ``tool``.

    Only the ``type``, ``properties`` and (when present) ``required`` entries
    of the tool schema's ``parameters`` are carried over. A schema without
    ``parameters`` yields an empty parameters dict, and a missing description
    becomes an empty string.
    """
    schema = tool.schema
    parameters: Dict[str, object] = {}
    if "parameters" in schema:
        param_schema = schema["parameters"]
        parameters = {
            "type": param_schema["type"],
            "properties": param_schema["properties"],
        }
        if "required" in param_schema:
            parameters["required"] = param_schema["required"]

    return FunctionToolParam(
        type="function",
        function=FunctionDefinition(
            name=schema["name"],
            description=schema.get("description", ""),
            parameters=parameters,
        ),
    )
class OpenAIAssistantAgentState(BaseModel):
    """Serializable snapshot of an ``OpenAIAssistantAgent``'s remote resources."""

    # Discriminator identifying the kind of state payload.
    type: str = Field(default="OpenAIAssistantAgentState")
    assistant_id: Optional[str] = None
    thread_id: Optional[str] = None
    # IDs of messages that were already in the thread when the agent first inspected it.
    initial_message_ids: List[str] = Field(default_factory=list)
    vector_store_id: Optional[str] = None
    uploaded_file_ids: List[str] = Field(default_factory=list)
class OpenAIAssistantAgent(BaseChatAgent):
"""An agent implementation that uses the Assistant API to generate responses.
Installation:
.. code-block:: bash
pip install "agentdhal-ext[openai]" # For OpenAI Assistant
# pip install "agentdhal-ext[openai,azure]" # For Azure OpenAI Assistant
This agent leverages the Assistant API to create AI assistants with capabilities like:
* Code interpretation and execution
* File handling and search
* Custom function calling
* Multi-turn conversations
The agent maintains a thread of conversation and can use various tools including
* Code interpreter: For executing code and working with files
* File search: For searching through uploaded documents
* Custom functions: For extending capabilities with user-defined tools
Key Features:
* Supports multiple file formats including code, documents, images
* Can handle up to 128 tools per assistant
* Maintains conversation context in threads
* Supports file uploads for code interpreter and search
* Vector store integration for efficient file search
* Automatic file parsing and embedding
You can use an existing thread or assistant by providing the `thread_id` or `assistant_id` parameters.
Examples:
Use the assistant to analyze data in a CSV file:
.. code-block:: python
from openai import AsyncOpenAI
from agentdhal_core import CancellationToken
import asyncio
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an OpenAI client
client = AsyncOpenAI(api_key="your-api-key", base_url="your-base-url")
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Upload files for the assistant to use
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
# Get response from the assistant
response = await assistant.on_messages(
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
)
print(response)
# Clean up resources
await assistant.delete_uploaded_files(cancellation_token)
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Use Azure OpenAI Assistant with AAD authentication:
.. code-block:: python
from openai import AsyncAzureOpenAI
import asyncio
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from agentdhal_core import CancellationToken
from agentdhal_extensions.agents.openai import OpenAIAssistantAgent
from agentdhal_agentchat.messages import TextMessage
async def example():
cancellation_token = CancellationToken()
# Create an Azure OpenAI client
token_provider = get_bearer_token_provider(DefaultAzureCredential())
client = AsyncAzureOpenAI(
azure_deployment="YOUR_AZURE_DEPLOYMENT",
api_version="YOUR_API_VERSION",
azure_endpoint="YOUR_AZURE_ENDPOINT",
azure_ad_token_provider=token_provider,
)
# Create an assistant with code interpreter
assistant = OpenAIAssistantAgent(
name="PythonHelper",
description="Helps with Python programming",
client=client,
model="gpt-4o",
instructions="You are a helpful Python programming assistant.",
tools=["code_interpreter"],
)
# Get response from the assistant
response = await assistant.on_messages([TextMessage(source="user", content="Hello.")], cancellation_token)
print(response)
# Clean up resources
await assistant.delete_assistant(cancellation_token)
asyncio.run(example())
Args:
name (str): Name of the assistant
description (str): Description of the assistant's purpose
client (AsyncOpenAI | AsyncAzureOpenAI): OpenAI client or Azure OpenAI client instance
model (str): Model to use (e.g. "gpt-4")
instructions (str): System instructions for the assistant
tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
assistant_id (Optional[str]): ID of existing assistant to use
thread_id (Optional[str]): ID of existing thread to use
metadata (Optional[Dict[str, str]]): Additional metadata for the assistant.
response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
temperature (Optional[float]): Temperature for response generation
tool_resources (Optional[ToolResources]): Additional tool configuration
top_p (Optional[float]): Top p sampling parameter
"""
def __init__(
    self,
    name: str,
    description: str,
    client: AsyncOpenAI | AsyncAzureOpenAI,
    model: str,
    instructions: str,
    tools: Optional[
        Iterable[
            Union[
                Literal["code_interpreter", "file_search"],
                Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
            ]
        ]
    ] = None,
    assistant_id: Optional[str] = None,
    thread_id: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    response_format: Optional["AssistantResponseFormatOptionParam"] = None,
    temperature: Optional[float] = None,
    tool_resources: Optional["ToolResources"] = None,
    top_p: Optional[float] = None,
) -> None:
    """Store the assistant settings and convert the tools to API parameters.

    No remote call is made here; the assistant and thread are created or
    retrieved later by ``_ensure_initialized``.

    Raises:
        ValueError: If ``client`` is a ``ChatCompletionClient``, if a string
            tool is not ``"code_interpreter"`` or ``"file_search"``, or if a
            tool is neither a string, a ``Tool`` nor a callable.
    """
    if isinstance(client, ChatCompletionClient):
        raise ValueError(
            "Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
        )
    super().__init__(name, description)
    if tools is None:
        tools = []

    # Store original tools and converted tools separately
    self._original_tools: List[Tool] = []
    converted_tools: List["AssistantToolParam"] = []
    for tool in tools:
        if isinstance(tool, str):
            if tool == "code_interpreter":
                converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
            elif tool == "file_search":
                converted_tools.append(FileSearchToolParam(type="file_search"))
            else:
                # Unknown names used to be dropped silently, leaving the
                # assistant without a tool the caller believed was enabled.
                raise ValueError(f"Unsupported built-in tool type: {tool}")
        elif isinstance(tool, Tool):
            self._original_tools.append(tool)
            converted_tools.append(_convert_tool_to_function_param(tool))
        elif callable(tool):
            # Use the callable's docstring as the tool description, without
            # overwriting the agent-level ``description`` argument.
            tool_description = getattr(tool, "__doc__", None) or ""
            function_tool = FunctionTool(tool, description=tool_description)
            self._original_tools.append(function_tool)
            converted_tools.append(_convert_tool_to_function_param(function_tool))
        else:
            raise ValueError(f"Unsupported tool type: {type(tool)}")

    self._client = client
    self._assistant: Optional["Assistant"] = None
    self._thread: Optional["Thread"] = None
    self._init_thread_id = thread_id
    self._model = model
    self._instructions = instructions
    self._api_tools = converted_tools
    self._assistant_id = assistant_id
    self._metadata = metadata
    self._response_format = response_format
    self._temperature = temperature
    self._tool_resources = tool_resources
    self._top_p = top_p
    self._vector_store_id: Optional[str] = None
    self._uploaded_file_ids: List[str] = []

    # Variables to track initial state
    self._initial_message_ids: Set[str] = set()
    self._initial_state_retrieved: bool = False
async def _ensure_initialized(self) -> None:
    """Lazily set up the remote assistant and thread, then snapshot the thread once.

    The assistant is retrieved when an ``assistant_id`` was supplied and created from the
    stored constructor settings otherwise. The thread is handled the same way using the
    initial thread ID. The first call also records which messages already exist in the
    thread; later calls skip that step.
    """
    if self._assistant is None:
        if not self._assistant_id:
            # Falsy optional settings are sent as NOT_GIVEN instead of an explicit value.
            response_format = self._response_format or NOT_GIVEN
            tool_resources = self._tool_resources or NOT_GIVEN
            self._assistant = await self._client.beta.assistants.create(  # type: ignore[reportDeprecated]
                model=self._model,
                description=self.description,
                instructions=self._instructions,
                tools=self._api_tools,
                metadata=self._metadata,
                response_format=response_format,  # type: ignore
                temperature=self._temperature,
                tool_resources=tool_resources,  # type: ignore
                top_p=self._top_p,
            )
        else:
            self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)  # type: ignore[reportDeprecated]

    if self._thread is None:
        if not self._init_thread_id:
            self._thread = await self._client.beta.threads.create()  # type: ignore[reportDeprecated]
        else:
            self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)  # type: ignore[reportDeprecated]

    # The snapshot of pre-existing messages is taken a single time per instance.
    if self._initial_state_retrieved:
        return
    await self._retrieve_initial_state()
    self._initial_state_retrieved = True
async def _retrieve_initial_state(self) -> None:
    """Record the IDs of all messages currently in the thread.

    Walks the thread's message list in ascending order, 100 messages per page, and stores
    the collected IDs in ``self._initial_message_ids``.
    """
    seen_ids: Set[str] = set()
    cursor: str | NotGiven = NOT_GIVEN
    more_pages = True
    while more_pages:
        page: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(  # type: ignore[reportDeprecated]
            self._thread_id, after=cursor, order="asc", limit=100
        )
        seen_ids.update(entry.id for entry in page.data)
        more_pages = page.has_next_page()
        if more_pages:
            # Continue after the last message of the page just read.
            cursor = page.data[-1].id
    self._initial_message_ids = seen_ids
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
    """Message classes this agent emits as its reply: only ``TextMessage``."""
    supported_types = (TextMessage,)
    return supported_types
@property
def threads(self) -> AsyncThreads:
    """The ``beta.threads`` resource of the underlying client."""
    beta_api = self._client.beta
    return beta_api.threads
@property
def runs(self) -> AsyncRuns:
    """The ``beta.threads.runs`` resource of the underlying client."""
    threads_api = self._client.beta.threads
    return threads_api.runs
@property
def messages(self) -> AsyncMessages:
    """The ``beta.threads.messages`` resource of the underlying client."""
    threads_api = self._client.beta.threads
    return threads_api.messages
@property
def _get_assistant_id(self) -> str:
    """ID of the assistant held by this agent.

    Raises:
        ValueError: If no assistant has been created or retrieved yet.
    """
    assistant = self._assistant
    if assistant is not None:
        return assistant.id
    raise ValueError("Assistant not initialized")
@property
def _thread_id(self) -> str:
    """ID of the thread held by this agent.

    Raises:
        ValueError: If no thread has been created or retrieved yet.
    """
    thread = self._thread
    if thread is not None:
        return thread.id
    raise ValueError("Thread not initialized")
async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
    """Run the registered tool named by ``tool_call`` and return its result as a string.

    Args:
        tool_call: The function call to execute; ``arguments`` is parsed as JSON.
        cancellation_token: Forwarded to the tool's ``run_json``.

    Raises:
        ValueError: If no tools are registered, or none matches ``tool_call.name``.
    """
    registered_tools = self._original_tools
    if not registered_tools:
        raise ValueError("No tools are available.")

    # Pick the first registered tool whose name matches the requested call.
    for candidate in registered_tools:
        if candidate.name == tool_call.name:
            break
    else:
        raise ValueError(f"The tool '{tool_call.name}' is not available.")

    parsed_arguments = json.loads(tool_call.arguments)
    outcome = await candidate.run_json(parsed_arguments, cancellation_token, call_id=tool_call.id)
    return candidate.return_value_as_string(outcome)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
    """Drive ``on_messages_stream`` and return the first ``Response`` it yields.

    Raises:
        AssertionError: If the stream finishes without yielding a ``Response``.
    """
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        # Intermediate events are ignored here; only the final Response is of interest.
        if not isinstance(item, Response):
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
    self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
    """Add the incoming messages to the thread, run the assistant, and stream the results.

    The run is polled until it completes. Whenever it asks for function calls, the matching
    tools are executed locally and their outputs are submitted back to the run.

    Args:
        messages: Messages to append to the thread before starting the run.
        cancellation_token: Linked to every client request made by this method.

    Yields:
        A ``ToolCallRequestEvent`` followed by a ``ToolCallExecutionEvent`` for each round of
        function calls, then a final ``Response`` whose chat message is the first text part of
        the most recent thread message.

    Raises:
        ValueError: If the run fails or ends in another unsuccessful terminal state
            (``cancelled``, ``expired`` or ``incomplete``), or if the most recent thread
            message is missing or has no text content.
    """
    await self._ensure_initialized()

    # Process all messages in sequence
    for message in messages:
        await self.handle_incoming_message(message, cancellation_token)

    # Inner messages for tool calls
    inner_messages: List[BaseAgentEvent | BaseChatMessage] = []

    # Create and start a run
    run: Run = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.runs.create(  # type: ignore[reportDeprecated]
                thread_id=self._thread_id,
                assistant_id=self._get_assistant_id,
            )
        )
    )

    # Wait for run completion by polling
    while True:
        run = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.runs.retrieve(  # type: ignore[reportDeprecated]
                    thread_id=self._thread_id,
                    run_id=run.id,
                )
            )
        )

        if run.status == "failed":
            raise ValueError(f"Run failed: {run.last_error}")

        # These statuses are terminal too: the run will never reach "completed", so polling
        # further would loop forever.
        if run.status in ("cancelled", "expired", "incomplete"):
            detail = getattr(run, "incomplete_details", None) or run.last_error
            raise ValueError(f"Run ended with status '{run.status}' before completing: {detail}")

        # If the run requires action (function calls), execute tools and continue
        if run.status == "requires_action" and run.required_action is not None:
            tool_calls: List[FunctionCall] = [
                FunctionCall(
                    id=required_tool_call.id,
                    name=required_tool_call.function.name,
                    arguments=required_tool_call.function.arguments,
                )
                for required_tool_call in run.required_action.submit_tool_outputs.tool_calls
                if required_tool_call.type == "function"
            ]

            # Add tool call message to inner messages
            tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
            inner_messages.append(tool_call_msg)
            event_logger.debug(tool_call_msg)
            yield tool_call_msg

            # Execute tool calls and get results
            tool_outputs: List[FunctionExecutionResult] = []
            for tool_call in tool_calls:
                try:
                    result = await self._execute_tool_call(tool_call, cancellation_token)
                    is_error = False
                except Exception as e:
                    # A failing tool is reported back to the model as an error output
                    # instead of aborting the whole run.
                    result = f"Error: {e}"
                    is_error = True
                tool_outputs.append(
                    FunctionExecutionResult(
                        content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
                    )
                )

            # Add tool result message to inner messages
            tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
            inner_messages.append(tool_result_msg)
            event_logger.debug(tool_result_msg)
            yield tool_result_msg

            # Submit tool outputs back to the run
            run = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.runs.submit_tool_outputs(  # type: ignore[reportDeprecated]
                        thread_id=self._thread_id,
                        run_id=run.id,
                        tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
                    )
                )
            )
            continue

        if run.status == "completed":
            break

        await asyncio.sleep(0.5)

    # Get messages after run completion
    assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
        asyncio.ensure_future(
            self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)  # type: ignore[reportDeprecated]
        )
    )

    if not assistant_messages.data:
        raise ValueError("No messages received from assistant")

    # Get the last message's content
    last_message = assistant_messages.data[0]
    if not last_message.content:
        raise ValueError(f"No content in the last message: {last_message}")

    # Extract text content
    text_content = [content for content in last_message.content if content.type == "text"]
    if not text_content:
        raise ValueError(f"Expected text content in the last message: {last_message.content}")

    # Return the assistant's response as a Response with inner messages
    chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
    yield Response(chat_message=chat_message, inner_messages=inner_messages)
async def handle_incoming_message(self, message: BaseChatMessage, cancellation_token: CancellationToken) -> None:
    """Append one incoming chat message to the thread with the ``user`` role.

    String content is sent unchanged. List content is converted part by part: strings become
    text blocks and ``Image`` items become image-URL blocks built from their data URI.

    Raises:
        ValueError: If a content part is neither a string nor an ``Image``.
    """
    model_content = message.to_model_message().content
    payload: str | List[MessageContentPartParam]
    if isinstance(model_content, str):
        payload = model_content
    else:
        payload = []
        for part in model_content:
            if isinstance(part, str):
                block = TextContentBlockParam(text=part, type="text")
            elif isinstance(part, Image):
                block = ImageURLContentBlockParam(image_url=ImageURLParam(url=part.data_uri), type="image_url")
            else:
                raise ValueError(f"Unsupported content type: {type(part)} in {message}")
            payload.append(block)

    create_request = self._client.beta.threads.messages.create(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        content=payload,
        role="user",
    )
    await cancellation_token.link_future(asyncio.ensure_future(create_request))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """Delete every thread message that was not present when the initial snapshot was taken.

    The thread is paged through in ascending order, 100 messages at a time. IDs missing from
    ``self._initial_message_ids`` are collected first and then deleted one by one.
    """
    await self._ensure_initialized()

    # First pass: gather the IDs of messages added since initialization.
    added_ids: List[str] = []
    cursor: str | NotGiven = NOT_GIVEN
    while True:
        page: AsyncCursorPage[Message] = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.list(self._thread_id, after=cursor, order="asc", limit=100)  # type: ignore[reportDeprecated]
            )
        )
        added_ids.extend(entry.id for entry in page.data if entry.id not in self._initial_message_ids)
        if not page.has_next_page():
            break
        cursor = page.data[-1].id

    # Second pass: remove them from the thread.
    for stale_id in added_ids:
        outcome: MessageDeleted = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.delete(message_id=stale_id, thread_id=self._thread_id)  # type: ignore[reportDeprecated]
            )
        )
        assert outcome.deleted is True
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
    """Upload local files with purpose ``assistants`` and return the new file IDs.

    Args:
        file_paths: A single path or an iterable of paths.
        cancellation_token: Linked to each file read and upload request.

    Returns:
        The IDs of the files uploaded by this call, in upload order. Each ID is also appended
        to ``self._uploaded_file_ids``.
    """
    await self._ensure_initialized()

    paths = [file_paths] if isinstance(file_paths, str) else file_paths
    new_ids: List[str] = []
    for path in paths:
        async with aiofiles.open(path, mode="rb") as handle:
            raw_bytes = await cancellation_token.link_future(asyncio.ensure_future(handle.read()))
        # The upload is named after the file only, without its directory.
        base_name = os.path.basename(path)
        upload_request = self._client.files.create(file=(base_name, raw_bytes), purpose="assistants")
        uploaded: FileObject = await cancellation_token.link_future(asyncio.ensure_future(upload_request))
        new_ids.append(uploaded.id)
        self._uploaded_file_ids.append(uploaded.id)
    return new_ids
async def on_upload_for_code_interpreter(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and attach them to the thread's code-interpreter resources.

    The thread is re-fetched so the new file IDs are appended to whatever code-interpreter
    file IDs it already lists, and the merged tool resources are written back to the thread.
    """
    await self._ensure_initialized()
    new_file_ids = await self._upload_files(file_paths, cancellation_token)

    current_thread = await cancellation_token.link_future(
        asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))  # type: ignore[reportDeprecated]
    )

    # Fall back to empty containers when the thread has no resources yet.
    resources: ToolResources = current_thread.tool_resources or ToolResources()
    interpreter: ToolResourcesCodeInterpreter = resources.code_interpreter or ToolResourcesCodeInterpreter()
    merged_ids: List[str] = interpreter.file_ids or []
    merged_ids.extend(new_file_ids)
    resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=merged_ids)

    update_request = self._client.beta.threads.update(  # type: ignore[reportDeprecated]
        thread_id=self._thread_id,
        tool_resources=cast(thread_update_params.ToolResources, resources.model_dump()),
    )
    await cancellation_token.link_future(asyncio.ensure_future(update_request))
async def on_upload_for_file_search(
    self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> None:
    """Upload files and add them to this agent's vector store for file search.

    On first use a vector store is created and the assistant is updated to reference it. The
    uploaded files are then added to the store as one batch.

    Raises:
        ValueError: If none of the configured API tools has type ``file_search``.
    """
    await self._ensure_initialized()

    configured_types = [tool.get("type") for tool in self._api_tools]
    if "file_search" not in configured_types:
        raise ValueError(
            "File search is not enabled for this assistant. Add a file_search tool when creating the assistant."
        )

    if self._vector_store_id is None:
        created_store: VectorStore = await cancellation_token.link_future(
            asyncio.ensure_future(self._client.vector_stores.create())
        )
        self._vector_store_id = created_store.id

        # Point the assistant's file_search tool at the store that was just created.
        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.assistants.update(
                    assistant_id=self._get_assistant_id,
                    tool_resources={"file_search": {"vector_store_ids": [self._vector_store_id]}},
                )
            )
        )

    new_file_ids = await self._upload_files(file_paths, cancellation_token)

    batch_request = self._client.vector_stores.file_batches.create_and_poll(
        vector_store_id=self._vector_store_id, file_ids=new_file_ids
    )
    await cancellation_token.link_future(asyncio.ensure_future(batch_request))
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
    """Delete the files tracked in ``self._uploaded_file_ids``, then clear the tracking list.

    Deletion is best effort: a failure for one file is logged and the remaining files are
    still attempted. The list is reset afterwards regardless of failures.
    """
    await self._ensure_initialized()
    for uploaded_id in self._uploaded_file_ids:
        try:
            deletion = asyncio.ensure_future(self._client.files.delete(file_id=uploaded_id))
            await cancellation_token.link_future(deletion)
        except Exception as exc:
            event_logger.error(f"Failed to delete file {uploaded_id}: {exc!s}")
    self._uploaded_file_ids = list()
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
    """Delete the remote assistant unless an ``assistant_id`` is set on this instance.

    Nothing is deleted when ``self._assistant_id`` is set. A failed deletion is logged rather
    than raised, and the local assistant reference is only cleared on success.
    """
    await self._ensure_initialized()
    if self._assistant is None or self._assistant_id:
        return
    try:
        await cancellation_token.link_future(
            asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))  # type: ignore[reportDeprecated]
        )
    except Exception as exc:
        event_logger.error(f"Failed to delete assistant: {exc!s}")
    else:
        self._assistant = None
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
    """Delete the vector store tracked by this instance, if there is one.

    A failed deletion is logged rather than raised, and the stored ID is only cleared on
    success.
    """
    await self._ensure_initialized()
    store_id = self._vector_store_id
    if store_id is None:
        return
    try:
        await cancellation_token.link_future(
            asyncio.ensure_future(self._client.vector_stores.delete(vector_store_id=store_id))
        )
    except Exception as exc:
        event_logger.error(f"Failed to delete vector store: {exc!s}")
    else:
        self._vector_store_id = None
async def save_state(self) -> Mapping[str, Any]:
    """Serialize the agent's identifying IDs into a plain mapping.

    The assistant and thread IDs come from the live objects when they exist, and from the
    configured IDs otherwise.
    """
    if self._assistant:
        assistant_id = self._assistant.id
    else:
        assistant_id = self._assistant_id

    if self._thread:
        thread_id = self._thread.id
    else:
        thread_id = self._init_thread_id

    snapshot = OpenAIAssistantAgentState(
        assistant_id=assistant_id,
        thread_id=thread_id,
        initial_message_ids=list(self._initial_message_ids),
        vector_store_id=self._vector_store_id,
        uploaded_file_ids=self._uploaded_file_ids,
    )
    return snapshot.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Validate ``state`` and copy its IDs onto this agent.

    Only the stored IDs are replaced. ``self._assistant`` and ``self._thread`` are not
    touched by this method.
    """
    restored = OpenAIAssistantAgentState.model_validate(state)
    # NOTE(review): ``_initial_state_retrieved`` is not updated here, so a later
    # ``_ensure_initialized`` may replace the restored message IDs with a fresh snapshot
    # -- confirm whether that is intended.
    self._assistant_id, self._init_thread_id = restored.assistant_id, restored.thread_id
    self._initial_message_ids = set(restored.initial_message_ids)
    self._vector_store_id, self._uploaded_file_ids = restored.vector_store_id, restored.uploaded_file_ids