first commit
This commit is contained in:
190
agent_dhal/agentdhal_extensions/tools/mcp/_base.py
Normal file
190
agent_dhal/agentdhal_extensions/tools/mcp/_base.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import json
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
from mcp.types import AudioContent, ContentBlock, EmbeddedResource, ImageContent, ResourceLink, TextContent
|
||||
|
||||
from ._config import McpServerParams
|
||||
from ._session import create_mcp_server_session
|
||||
|
||||
TServerParams = TypeVar("TServerParams", bound=McpServerParams)
|
||||
|
||||
|
||||
class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
|
||||
"""
|
||||
Base adapter class for MCP tools to make them compatible with AutoGen.
|
||||
|
||||
Args:
|
||||
server_params (TServerParams): Parameters for the MCP server connection.
|
||||
tool (Tool): The MCP tool to wrap.
|
||||
"""
|
||||
|
||||
component_type = "tool"
|
||||
|
||||
def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
|
||||
self._tool = tool
|
||||
self._server_params = server_params
|
||||
self._session = session
|
||||
|
||||
# Extract name and description
|
||||
name = tool.name
|
||||
description = tool.description or ""
|
||||
|
||||
# Create the input model from the tool's schema
|
||||
input_model = schema_to_pydantic_model(tool.inputSchema)
|
||||
|
||||
# Use Any as return type since MCP tool returns can vary
|
||||
return_type: Type[Any] = object
|
||||
|
||||
super().__init__(input_model, return_type, name, description)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
"""
|
||||
Run the MCP tool with the provided arguments.
|
||||
|
||||
Args:
|
||||
args (BaseModel): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): Token to signal cancellation.
|
||||
|
||||
Returns:
|
||||
Any: The result of the tool execution.
|
||||
|
||||
Raises:
|
||||
Exception: If the operation is cancelled or the tool execution fails.
|
||||
"""
|
||||
# Convert the input model to a dictionary
|
||||
# Exclude unset values to avoid sending them to the MCP servers which may cause errors
|
||||
# for many servers.
|
||||
kwargs = args.model_dump(exclude_unset=True)
|
||||
|
||||
if self._session is not None:
|
||||
# If a session is provided, use it directly.
|
||||
session = self._session
|
||||
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
|
||||
|
||||
async with create_mcp_server_session(self._server_params) as session:
|
||||
await session.initialize()
|
||||
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
|
||||
|
||||
def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]:
|
||||
"""
|
||||
Normalizes a raw tool output payload into a list of content items.
|
||||
- If payload is already a sequence of ContentBlock items, it's converted to a list and returned.
|
||||
- If payload is a single ContentBlock item, it's wrapped in a list.
|
||||
- If payload is a string, it's wrapped in [TextContent(text=payload)].
|
||||
- Otherwise, the payload is stringified and wrapped in [TextContent(text=str(payload))].
|
||||
"""
|
||||
if isinstance(payload, Sequence) and all(
|
||||
isinstance(item, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink))
|
||||
for item in payload
|
||||
):
|
||||
return list(payload)
|
||||
elif isinstance(payload, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)):
|
||||
return [payload]
|
||||
elif isinstance(payload, str):
|
||||
return [TextContent(text=payload, type="text")]
|
||||
else:
|
||||
return [TextContent(text=str(payload), type="text")]
|
||||
|
||||
async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
|
||||
exceptions_to_catch: tuple[Type[BaseException], ...]
|
||||
if hasattr(builtins, "ExceptionGroup"):
|
||||
exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup)
|
||||
else:
|
||||
exceptions_to_catch = (asyncio.CancelledError,)
|
||||
|
||||
try:
|
||||
if cancellation_token.is_cancelled():
|
||||
raise asyncio.CancelledError("Operation cancelled")
|
||||
|
||||
result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
|
||||
normalized_content_list = self._normalize_payload_to_content_list(result.content)
|
||||
|
||||
if result.isError:
|
||||
serialized_error_message = self.return_value_as_string(normalized_content_list)
|
||||
raise Exception(serialized_error_message)
|
||||
return normalized_content_list
|
||||
|
||||
except exceptions_to_catch:
|
||||
# Re-raise these specific exception types directly.
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
|
||||
"""
|
||||
Create an instance of McpToolAdapter from server parameters and tool name.
|
||||
|
||||
Args:
|
||||
server_params (TServerParams): Parameters for the MCP server connection.
|
||||
tool_name (str): The name of the tool to wrap.
|
||||
|
||||
Returns:
|
||||
McpToolAdapter[TServerParams]: An instance of McpToolAdapter.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tool with the specified name is not found.
|
||||
"""
|
||||
async with create_mcp_server_session(server_params) as session:
|
||||
await session.initialize()
|
||||
|
||||
tools_response = await session.list_tools()
|
||||
matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None)
|
||||
|
||||
if matching_tool is None:
|
||||
raise ValueError(
|
||||
f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}"
|
||||
)
|
||||
|
||||
return cls(server_params=server_params, tool=matching_tool)
|
||||
|
||||
def return_value_as_string(self, value: list[Any]) -> str:
|
||||
"""Return a string representation of the result."""
|
||||
|
||||
def serialize_item(item: Any) -> dict[str, Any]:
|
||||
if isinstance(item, (TextContent, ImageContent, AudioContent)):
|
||||
dumped = item.model_dump()
|
||||
# Remove the 'meta' field if it exists and is None (for backward compatibility)
|
||||
if dumped.get("meta") is None:
|
||||
dumped.pop("meta", None)
|
||||
return dumped
|
||||
elif isinstance(item, EmbeddedResource):
|
||||
type = item.type
|
||||
resource = {}
|
||||
for key, val in item.resource.model_dump().items():
|
||||
# Skip 'meta' field if it's None (for backward compatibility)
|
||||
if key == "meta" and val is None:
|
||||
continue
|
||||
if isinstance(val, AnyUrl):
|
||||
resource[key] = str(val)
|
||||
else:
|
||||
resource[key] = val
|
||||
dumped_annotations = item.annotations.model_dump() if item.annotations else None
|
||||
# Remove 'meta' from annotations if it exists and is None
|
||||
if dumped_annotations and dumped_annotations.get("meta") is None:
|
||||
dumped_annotations.pop("meta", None)
|
||||
return {"type": type, "resource": resource, "annotations": dumped_annotations}
|
||||
elif isinstance(item, ResourceLink):
|
||||
dumped = item.model_dump()
|
||||
# Remove the 'meta' field if it exists and is None (for backward compatibility)
|
||||
if dumped.get("meta") is None:
|
||||
dumped.pop("meta", None)
|
||||
# Convert AnyUrl to string for JSON serialization
|
||||
if "uri" in dumped and isinstance(dumped["uri"], AnyUrl):
|
||||
dumped["uri"] = str(dumped["uri"])
|
||||
return dumped
|
||||
else:
|
||||
return {}
|
||||
|
||||
return json.dumps([serialize_item(item) for item in value])
|
||||
Reference in New Issue
Block a user