Files

191 lines
8.1 KiB
Python
Raw Permalink Normal View History

2026-03-13 12:56:43 -07:00
import asyncio
import builtins
import json
from abc import ABC
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from agentdhal_core.utils import schema_to_pydantic_model
from pydantic import BaseModel
from pydantic.networks import AnyUrl
from mcp import ClientSession, Tool
from mcp.types import AudioContent, ContentBlock, EmbeddedResource, ImageContent, ResourceLink, TextContent
from ._config import McpServerParams
from ._session import create_mcp_server_session
# Type variable for the server-connection parameters an adapter is parameterized by;
# bound to McpServerParams, so only McpServerParams and its subtypes are accepted.
TServerParams = TypeVar("TServerParams", bound=McpServerParams)
class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
    """
    Base adapter class for MCP tools to make them compatible with AutoGen.

    Args:
        server_params (TServerParams): Parameters for the MCP server connection.
        tool (Tool): The MCP tool to wrap.
        session (ClientSession | None): Optional existing MCP client session. When
            provided, :meth:`run` calls the tool on it as-is (``run`` does not
            initialize it). When omitted, ``run`` opens and initializes a new
            session from ``server_params`` for every call.
    """

    component_type = "tool"

    def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
        self._tool = tool
        self._server_params = server_params
        self._session = session

        # Expose the MCP tool's name and description; a missing description becomes "".
        name = tool.name
        description = tool.description or ""

        # Build the pydantic input model from the tool's JSON input schema.
        input_model = schema_to_pydantic_model(tool.inputSchema)

        # MCP tool results can take many shapes, so the declared return type is left open.
        return_type: Type[Any] = object
        super().__init__(input_model, return_type, name, description)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """
        Run the MCP tool with the provided arguments.

        Args:
            args (BaseModel): The arguments to pass to the tool.
            cancellation_token (CancellationToken): Token to signal cancellation.

        Returns:
            Any: The tool's result content, normalized to a list of content blocks.

        Raises:
            asyncio.CancelledError: If the operation is cancelled.
            Exception: If the tool reports an error; the message is the JSON-serialized
                error content. Errors raised by the session itself propagate unchanged.
        """
        # Exclude unset values: sending them to MCP servers may cause errors for many servers.
        kwargs = args.model_dump(exclude_unset=True)

        if self._session is not None:
            # A caller-supplied session is used directly, without re-initializing it.
            return await self._run(args=kwargs, cancellation_token=cancellation_token, session=self._session)

        # No shared session: open a short-lived one just for this call.
        async with create_mcp_server_session(self._server_params) as session:
            await session.initialize()
            return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)

    def _normalize_payload_to_content_list(self, payload: Any) -> list[ContentBlock]:
        """
        Normalize a raw tool output payload into a list of content blocks.

        - A string is wrapped as ``[TextContent(text=payload)]``. This is checked first
          because ``str`` is itself a ``Sequence``; otherwise an empty string would be
          mistaken for an empty list of content blocks.
        - A sequence whose items are all content blocks is converted to a list.
        - A single content block is wrapped in a list.
        - Anything else is stringified and wrapped as ``[TextContent(text=str(payload))]``.
        """
        content_types = (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)
        if isinstance(payload, str):
            return [TextContent(text=payload, type="text")]
        if isinstance(payload, Sequence) and all(isinstance(item, content_types) for item in payload):
            return list(payload)
        if isinstance(payload, content_types):
            return [payload]
        return [TextContent(text=str(payload), type="text")]

    async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
        """
        Call the wrapped tool on ``session`` and normalize its result.

        Args:
            args: Keyword arguments for the tool call.
            cancellation_token: Checked before the call starts and linked to the
                in-flight call future via ``link_future``.
            session: The MCP client session used for ``call_tool``.

        Returns:
            The tool's result content as a list of content blocks.

        Raises:
            asyncio.CancelledError: If the token is already cancelled when the call
                starts (and, presumably, if the linked future is cancelled mid-flight).
            Exception: If the tool result has ``isError`` set; the message is the
                JSON-serialized error content.
        """
        # Fail fast rather than starting a call whose result would be discarded.
        if cancellation_token.is_cancelled():
            raise asyncio.CancelledError("Operation cancelled")

        # Run the call as a future so it can be linked to the cancellation token.
        result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
        cancellation_token.link_future(result_future)
        result = await result_future

        normalized_content_list = self._normalize_payload_to_content_list(result.content)
        if result.isError:
            # Surface the tool's own error content as the exception message.
            serialized_error_message = self.return_value_as_string(normalized_content_list)
            raise Exception(serialized_error_message)
        return normalized_content_list

    @classmethod
    async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
        """
        Create an instance of McpToolAdapter from server parameters and tool name.

        A temporary session is opened only to discover the tool. The returned adapter
        is not bound to that session; it opens its own session on each ``run``.

        Args:
            server_params (TServerParams): Parameters for the MCP server connection.
            tool_name (str): The name of the tool to wrap.

        Returns:
            McpToolAdapter[TServerParams]: An instance of McpToolAdapter.

        Raises:
            ValueError: If the tool with the specified name is not found.
        """
        async with create_mcp_server_session(server_params) as session:
            await session.initialize()
            tools_response = await session.list_tools()
            matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None)
            if matching_tool is None:
                raise ValueError(
                    f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}"
                )

        return cls(server_params=server_params, tool=matching_tool)

    def return_value_as_string(self, value: list[Any]) -> str:
        """
        Return a JSON string representation of a list of MCP content blocks.

        Text, image, and audio content and resource links are dumped via pydantic.
        Embedded resources are flattened into ``{"type", "resource", "annotations"}``.
        ``meta`` entries whose value is ``None`` are dropped for backward
        compatibility, and ``AnyUrl`` values are converted to strings so
        ``json.dumps`` can encode them. Items of any other type serialize as ``{}``.
        """

        def drop_none_meta(data: dict[str, Any]) -> dict[str, Any]:
            # Omit a None 'meta' field (for backward compatibility); mutates and returns data.
            if data.get("meta") is None:
                data.pop("meta", None)
            return data

        def serialize_item(item: Any) -> dict[str, Any]:
            if isinstance(item, (TextContent, ImageContent, AudioContent)):
                return drop_none_meta(item.model_dump())
            if isinstance(item, EmbeddedResource):
                resource: dict[str, Any] = {}
                for key, val in item.resource.model_dump().items():
                    if key == "meta" and val is None:
                        continue
                    resource[key] = str(val) if isinstance(val, AnyUrl) else val
                dumped_annotations = item.annotations.model_dump() if item.annotations else None
                if dumped_annotations:
                    drop_none_meta(dumped_annotations)
                return {"type": item.type, "resource": resource, "annotations": dumped_annotations}
            if isinstance(item, ResourceLink):
                dumped = drop_none_meta(item.model_dump())
                # Convert AnyUrl to string for JSON serialization.
                if isinstance(dumped.get("uri"), AnyUrl):
                    dumped["uri"] = str(dumped["uri"])
                return dumped
            return {}

        return json.dumps([serialize_item(item) for item in value])