first commit
This commit is contained in:
19
agent_dhal/agentdhal_extensions/tools/azure/__init__.py
Normal file
19
agent_dhal/agentdhal_extensions/tools/azure/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from ._ai_search import (
|
||||
AzureAISearchTool,
|
||||
BaseAzureAISearchTool,
|
||||
SearchQuery,
|
||||
SearchResult,
|
||||
SearchResults,
|
||||
VectorizableTextQuery,
|
||||
)
|
||||
from ._config import AzureAISearchConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAISearchTool",
|
||||
"BaseAzureAISearchTool",
|
||||
"SearchQuery",
|
||||
"SearchResult",
|
||||
"SearchResults",
|
||||
"AzureAISearchConfig",
|
||||
"VectorizableTextQuery",
|
||||
]
|
||||
1137
agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py
Normal file
1137
agent_dhal/agentdhal_extensions/tools/azure/_ai_search.py
Normal file
File diff suppressed because it is too large
Load Diff
186
agent_dhal/agentdhal_extensions/tools/azure/_config.py
Normal file
186
agent_dhal/agentdhal_extensions/tools/azure/_config.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Configuration for Azure AI Search tool.
|
||||
|
||||
This module provides configuration classes for the Azure AI Search tool, including
|
||||
settings for authentication, search behavior, retry policies, and caching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
T = TypeVar("T", bound="AzureAISearchConfig")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
|
||||
DEFAULT_API_VERSION = "2023-10-01-preview"
|
||||
|
||||
|
||||
class AzureAISearchConfig(BaseModel):
|
||||
"""Configuration for Azure AI Search with validation.
|
||||
|
||||
This class defines the configuration parameters for Azure AI Search tools, including
|
||||
authentication, search behavior, caching, and embedding settings.
|
||||
|
||||
.. note::
|
||||
This class requires the ``azure`` extra for the ``autogen-ext`` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[azure]"
|
||||
|
||||
.. note::
|
||||
**Prerequisites:**
|
||||
|
||||
1. An Azure AI Search service must be created in your Azure subscription.
|
||||
2. The search index must be properly configured for your use case:
|
||||
|
||||
- For vector search: Index must have vector fields
|
||||
- For semantic search: Index must have semantic configuration
|
||||
- For hybrid search: Both vector fields and text fields must be configured
|
||||
3. Required packages:
|
||||
|
||||
- Base functionality: ``azure-search-documents>=11.4.0``
|
||||
- For Azure OpenAI embeddings: ``openai azure-identity``
|
||||
- For OpenAI embeddings: ``openai``
|
||||
|
||||
Example Usage:
|
||||
.. code-block:: python
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from agentdhal_extensions.tools.azure import AzureAISearchConfig
|
||||
|
||||
# Basic configuration for full-text search
|
||||
config = AzureAISearchConfig(
|
||||
name="doc-search",
|
||||
endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint
|
||||
index_name="<your-index>", # Name of your search index
|
||||
credential=AzureKeyCredential("<your-key>"), # Your Azure AI Search admin key
|
||||
query_type="simple",
|
||||
search_fields=["content", "title"], # Update with your searchable fields
|
||||
top=5,
|
||||
)
|
||||
|
||||
# Configuration for vector search with Azure OpenAI embeddings
|
||||
vector_config = AzureAISearchConfig(
|
||||
name="vector-search",
|
||||
endpoint="https://your-search.search.windows.net",
|
||||
index_name="<your-index>",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="vector",
|
||||
vector_fields=["embedding"], # Update with your vector field name
|
||||
embedding_provider="azure_openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint
|
||||
openai_api_key="<your-openai-key>", # Your Azure OpenAI key
|
||||
top=5,
|
||||
)
|
||||
|
||||
# Configuration for hybrid search with semantic ranking
|
||||
hybrid_config = AzureAISearchConfig(
|
||||
name="hybrid-search",
|
||||
endpoint="https://your-search.search.windows.net",
|
||||
index_name="<your-index>",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="semantic",
|
||||
semantic_config_name="<your-semantic-config>", # Name of your semantic configuration
|
||||
search_fields=["content", "title"], # Update with your search fields
|
||||
vector_fields=["embedding"], # Update with your vector field name
|
||||
embedding_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
openai_api_key="<your-openai-key>", # Your OpenAI API key
|
||||
top=5,
|
||||
)
|
||||
"""
|
||||
|
||||
name: str = Field(description="The name of this tool instance")
|
||||
description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
|
||||
endpoint: str = Field(description="The full URL of your Azure AI Search service")
|
||||
index_name: str = Field(description="Name of the search index to query")
|
||||
credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
|
||||
description="Azure credential for authentication (API key or token)"
|
||||
)
|
||||
api_version: str = Field(
|
||||
default=DEFAULT_API_VERSION,
|
||||
description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
|
||||
)
|
||||
query_type: QueryTypeLiteral = Field(
|
||||
default="simple", description="Type of search to perform: simple, full, semantic, or vector"
|
||||
)
|
||||
search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
|
||||
select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
|
||||
vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
|
||||
top: Optional[int] = Field(
|
||||
default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
|
||||
)
|
||||
filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
|
||||
semantic_config_name: Optional[str] = Field(
|
||||
default=None, description="Semantic configuration name for enhanced results"
|
||||
)
|
||||
|
||||
enable_caching: bool = Field(default=False, description="Whether to cache search results")
|
||||
cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")
|
||||
|
||||
embedding_provider: Optional[str] = Field(
|
||||
default=None, description="Name of embedding provider for client-side embeddings"
|
||||
)
|
||||
embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
|
||||
openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
|
||||
openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
|
||||
openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@field_validator("endpoint")
|
||||
def validate_endpoint(cls, v: str) -> str:
|
||||
"""Validate that the endpoint is a valid URL."""
|
||||
if not v.startswith(("http://", "https://")):
|
||||
raise ValueError("endpoint must be a valid URL starting with http:// or https://")
|
||||
return v
|
||||
|
||||
@field_validator("query_type")
|
||||
def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral:
|
||||
"""Normalize query type to standard values."""
|
||||
if not v:
|
||||
return "simple"
|
||||
|
||||
if isinstance(v, str) and v.lower() == "fulltext":
|
||||
return "full"
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("top")
|
||||
def validate_top(cls, v: Optional[int]) -> Optional[int]:
|
||||
"""Ensure top is a positive integer if provided."""
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError("top must be a positive integer")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_interdependent_fields(self) -> "AzureAISearchConfig":
|
||||
"""Validate interdependent fields after all fields have been parsed."""
|
||||
if self.query_type == "semantic" and not self.semantic_config_name:
|
||||
raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")
|
||||
|
||||
if self.query_type == "vector" and not self.vector_fields:
|
||||
raise ValueError("vector_fields must be provided for vector search")
|
||||
|
||||
if (
|
||||
self.embedding_provider
|
||||
and self.embedding_provider.lower() == "azure_openai"
|
||||
and self.embedding_model
|
||||
and not self.openai_endpoint
|
||||
):
|
||||
raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")
|
||||
|
||||
return self
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
|
||||
__all__ = ["CodeExecutionInput", "CodeExecutionResult", "PythonCodeExecutionTool"]
|
||||
@@ -0,0 +1,96 @@
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel
|
||||
from agentdhal_core.code_executor import CodeBlock, CodeExecutor
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class CodeExecutionInput(BaseModel):
|
||||
code: str = Field(description="The contents of the Python code block that should be executed")
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
|
||||
success: bool
|
||||
output: str
|
||||
|
||||
@model_serializer
|
||||
def ser_model(self) -> str:
|
||||
return self.output
|
||||
|
||||
|
||||
class PythonCodeExecutionToolConfig(BaseModel):
|
||||
"""Configuration for PythonCodeExecutionTool"""
|
||||
|
||||
executor: ComponentModel
|
||||
description: str = "Execute Python code blocks."
|
||||
|
||||
|
||||
class PythonCodeExecutionTool(
|
||||
BaseTool[CodeExecutionInput, CodeExecutionResult], Component[PythonCodeExecutionToolConfig]
|
||||
):
|
||||
"""A tool that executes Python code in a code executor and returns output.
|
||||
|
||||
Example executors:
|
||||
|
||||
* :class:`agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor`
|
||||
* :class:`agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
|
||||
* :class:`agentdhal_extensions.code_executors.azure.ACADynamicSessionsCodeExecutor`
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-agentchat" "agentdhal-ext[openai]" "yfinance" "matplotlib"
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
|
||||
from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
tool = PythonCodeExecutionTool(LocalCommandLineCodeExecutor(work_dir="coding"))
|
||||
agent = AssistantAgent(
|
||||
"assistant", OpenAIChatCompletionClient(model="gpt-4o"), tools=[tool], reflect_on_tool_use=True
|
||||
)
|
||||
await Console(
|
||||
agent.run_stream(
|
||||
task="Create a plot of MSFT stock prices in 2024 and save it to a file. Use yfinance and matplotlib."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
executor (CodeExecutor): The code executor that will be used to execute the code blocks.
|
||||
"""
|
||||
|
||||
component_config_schema = PythonCodeExecutionToolConfig
|
||||
component_provider_override = "agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool"
|
||||
|
||||
def __init__(self, executor: CodeExecutor):
|
||||
super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
|
||||
self._executor = executor
|
||||
|
||||
async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
|
||||
code_blocks = [CodeBlock(code=args.code, language="python")]
|
||||
result = await self._executor.execute_code_blocks(
|
||||
code_blocks=code_blocks, cancellation_token=cancellation_token
|
||||
)
|
||||
return CodeExecutionResult(success=result.exit_code == 0, output=result.output)
|
||||
|
||||
def _to_config(self) -> PythonCodeExecutionToolConfig:
|
||||
"""Convert current instance to config object"""
|
||||
return PythonCodeExecutionToolConfig(executor=self._executor.dump_component())
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: PythonCodeExecutionToolConfig) -> Self:
|
||||
"""Create instance from config object"""
|
||||
executor = CodeExecutor.load_component(config.executor)
|
||||
return cls(executor=executor)
|
||||
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from ._config import (
|
||||
GlobalContextConfig,
|
||||
GlobalDataConfig,
|
||||
LocalContextConfig,
|
||||
LocalDataConfig,
|
||||
MapReduceConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
|
||||
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
|
||||
|
||||
__all__ = [
|
||||
"GlobalSearchTool",
|
||||
"LocalSearchTool",
|
||||
"GlobalDataConfig",
|
||||
"LocalDataConfig",
|
||||
"GlobalContextConfig",
|
||||
"GlobalSearchToolArgs",
|
||||
"GlobalSearchToolReturn",
|
||||
"LocalContextConfig",
|
||||
"LocalSearchToolArgs",
|
||||
"LocalSearchToolReturn",
|
||||
"MapReduceConfig",
|
||||
"SearchConfig",
|
||||
]
|
||||
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
input_dir: str
|
||||
entity_table: str = "entities"
|
||||
entity_embedding_table: str = "entities"
|
||||
community_table: str = "communities"
|
||||
community_level: int = 2
|
||||
|
||||
|
||||
class GlobalDataConfig(DataConfig):
|
||||
community_report_table: str = "community_reports"
|
||||
|
||||
|
||||
class LocalDataConfig(DataConfig):
|
||||
relationship_table: str = "relationships"
|
||||
text_unit_table: str = "text_units"
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
max_data_tokens: int = 8000
|
||||
|
||||
|
||||
class GlobalContextConfig(ContextConfig):
|
||||
use_community_summary: bool = False
|
||||
shuffle_data: bool = True
|
||||
include_community_rank: bool = True
|
||||
min_community_rank: int = 0
|
||||
community_rank_name: str = "rank"
|
||||
include_community_weight: bool = True
|
||||
community_weight_name: str = "occurrence weight"
|
||||
normalize_community_weight: bool = True
|
||||
max_data_tokens: int = 12000
|
||||
|
||||
|
||||
class LocalContextConfig(ContextConfig):
|
||||
text_unit_prop: float = 0.5
|
||||
community_prop: float = 0.25
|
||||
include_entity_rank: bool = True
|
||||
rank_description: str = "number of relationships"
|
||||
include_relationship_weight: bool = True
|
||||
relationship_ranking_attribute: str = "rank"
|
||||
|
||||
|
||||
class MapReduceConfig(BaseModel):
|
||||
map_max_tokens: int = 1000
|
||||
map_temperature: float = 0.0
|
||||
reduce_max_tokens: int = 2000
|
||||
reduce_temperature: float = 0.0
|
||||
allow_general_knowledge: bool = False
|
||||
json_mode: bool = False
|
||||
response_type: str = "multiple paragraphs"
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
max_tokens: int = 1500
|
||||
temperature: float = 0.0
|
||||
response_type: str = "multiple paragraphs"
|
||||
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_communities,
|
||||
read_indexer_entities,
|
||||
read_indexer_reports,
|
||||
)
|
||||
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
|
||||
from graphrag.query.structured_search.global_search.search import GlobalSearch
|
||||
|
||||
from ._config import GlobalContextConfig as ContextConfig
|
||||
from ._config import GlobalDataConfig as DataConfig
|
||||
from ._config import MapReduceConfig
|
||||
|
||||
_default_context_config = ContextConfig()
|
||||
_default_mapreduce_config = MapReduceConfig()
|
||||
|
||||
|
||||
class GlobalSearchToolArgs(BaseModel):
|
||||
query: str = Field(..., description="The user query to perform global search on.")
|
||||
|
||||
|
||||
class GlobalSearchToolReturn(BaseModel):
|
||||
answer: str
|
||||
|
||||
|
||||
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
|
||||
"""Enables running GraphRAG global search queries as an AutoGen tool.
|
||||
|
||||
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
|
||||
The search combines graph-based document relationships with semantic embeddings to find relevant information.
|
||||
|
||||
.. note::
|
||||
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
|
||||
|
||||
To install:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"
|
||||
|
||||
Before using this tool, you must complete the GraphRAG setup and indexing process:
|
||||
|
||||
1. Follow the GraphRAG documentation to initialize your project and settings
|
||||
2. Configure and tune your prompts for the specific use case
|
||||
3. Run the indexing process to generate the required data files
|
||||
4. Ensure you have the settings.yaml file from the setup process
|
||||
|
||||
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
|
||||
for detailed instructions on completing these prerequisite steps.
|
||||
|
||||
Example usage with AssistantAgent:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.tools.graphrag import GlobalSearchTool
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize the OpenAI client
|
||||
openai_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o-mini",
|
||||
api_key="<api-key>",
|
||||
)
|
||||
|
||||
# Set up global search tool
|
||||
global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
|
||||
|
||||
# Create assistant agent with the global search tool
|
||||
assistant_agent = AssistantAgent(
|
||||
name="search_assistant",
|
||||
tools=[global_tool],
|
||||
model_client=openai_client,
|
||||
system_message=(
|
||||
"You are a tool selector AI assistant using the GraphRAG framework. "
|
||||
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
|
||||
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
|
||||
),
|
||||
)
|
||||
|
||||
# Run a sample query
|
||||
query = "What is the overall sentiment of the community reports?"
|
||||
await Console(assistant_agent.run_stream(task=query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
model: ChatModel,
|
||||
data_config: DataConfig,
|
||||
context_config: ContextConfig = _default_context_config,
|
||||
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
|
||||
):
|
||||
super().__init__(
|
||||
args_type=GlobalSearchToolArgs,
|
||||
return_type=GlobalSearchToolReturn,
|
||||
name="global_search_tool",
|
||||
description="Perform a global search with given parameters using graphrag.",
|
||||
)
|
||||
# Use the provided model
|
||||
self._model = model
|
||||
|
||||
# Load parquet files
|
||||
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
|
||||
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
|
||||
report_df: pd.DataFrame = pd.read_parquet( # type: ignore
|
||||
f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
|
||||
)
|
||||
|
||||
# Fix: Use correct argument order and types for GraphRAG API
|
||||
communities = read_indexer_communities(community_df, report_df)
|
||||
reports = read_indexer_reports(report_df, community_df, data_config.community_level)
|
||||
entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
|
||||
|
||||
context_builder = GlobalCommunityContext(
|
||||
community_reports=reports,
|
||||
communities=communities,
|
||||
entities=entities,
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
|
||||
context_builder_params = {
|
||||
"use_community_summary": context_config.use_community_summary,
|
||||
"shuffle_data": context_config.shuffle_data,
|
||||
"include_community_rank": context_config.include_community_rank,
|
||||
"min_community_rank": context_config.min_community_rank,
|
||||
"community_rank_name": context_config.community_rank_name,
|
||||
"include_community_weight": context_config.include_community_weight,
|
||||
"community_weight_name": context_config.community_weight_name,
|
||||
"normalize_community_weight": context_config.normalize_community_weight,
|
||||
"max_tokens": context_config.max_data_tokens,
|
||||
"context_name": "Reports",
|
||||
}
|
||||
|
||||
map_llm_params = {
|
||||
"max_tokens": mapreduce_config.map_max_tokens,
|
||||
"temperature": mapreduce_config.map_temperature,
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
|
||||
reduce_llm_params = {
|
||||
"max_tokens": mapreduce_config.reduce_max_tokens,
|
||||
"temperature": mapreduce_config.reduce_temperature,
|
||||
}
|
||||
|
||||
self._search_engine = GlobalSearch(
|
||||
model=self._model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
max_data_tokens=context_config.max_data_tokens,
|
||||
map_llm_params=map_llm_params,
|
||||
reduce_llm_params=reduce_llm_params,
|
||||
allow_general_knowledge=mapreduce_config.allow_general_knowledge,
|
||||
json_mode=mapreduce_config.json_mode,
|
||||
context_builder_params=context_builder_params,
|
||||
concurrent_coroutines=32,
|
||||
response_type=mapreduce_config.response_type,
|
||||
)
|
||||
|
||||
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
|
||||
search_result = await self._search_engine.search(args.query)
|
||||
assert isinstance(search_result.response, str), "Expected response to be a string"
|
||||
return GlobalSearchToolReturn(answer=search_result.response)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
|
||||
"""Create a GlobalSearchTool instance from GraphRAG settings file.
|
||||
|
||||
Args:
|
||||
root_dir: Path to the GraphRAG root directory
|
||||
config_filepath: Path to the GraphRAG settings file (optional)
|
||||
|
||||
Returns:
|
||||
An initialized GlobalSearchTool instance
|
||||
"""
|
||||
# Load GraphRAG config
|
||||
if isinstance(root_dir, str):
|
||||
root_dir = Path(root_dir)
|
||||
if isinstance(config_filepath, str):
|
||||
config_filepath = Path(config_filepath)
|
||||
config = load_config(root_dir=root_dir, config_filepath=config_filepath)
|
||||
|
||||
# Get the language model configuration from the models section
|
||||
chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
|
||||
|
||||
if chat_model_config is None:
|
||||
raise ValueError("default_chat_model not found in config.models")
|
||||
|
||||
# Initialize token encoder based on the model being used
|
||||
try:
|
||||
token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
|
||||
except KeyError:
|
||||
# Fallback to cl100k_base if model is not recognized by tiktoken
|
||||
token_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# Create the LLM using ModelManager
|
||||
model = ModelManager().get_or_create_chat_model(
|
||||
name="global_search_model",
|
||||
model_type=chat_model_config.type,
|
||||
config=chat_model_config,
|
||||
)
|
||||
|
||||
# Create data config from storage paths
|
||||
data_config = DataConfig(
|
||||
input_dir=str(config.output.base_dir),
|
||||
)
|
||||
|
||||
return cls(
|
||||
token_encoder=token_encoder,
|
||||
model=model,
|
||||
data_config=data_config,
|
||||
context_config=_default_context_config,
|
||||
mapreduce_config=_default_mapreduce_config,
|
||||
)
|
||||
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# mypy: disable-error-code="no-any-unimported,misc"
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
from ._config import LocalContextConfig, SearchConfig
|
||||
from ._config import LocalDataConfig as DataConfig
|
||||
|
||||
_default_context_config = LocalContextConfig()
|
||||
_default_search_config = SearchConfig()
|
||||
|
||||
|
||||
class LocalSearchToolArgs(BaseModel):
|
||||
query: str = Field(..., description="The user query to perform local search on.")
|
||||
|
||||
|
||||
class LocalSearchToolReturn(BaseModel):
|
||||
answer: str = Field(..., description="The answer to the user query.")
|
||||
|
||||
|
||||
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
|
||||
"""Enables running GraphRAG local search queries as an AutoGen tool.
|
||||
|
||||
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
|
||||
The search combines local document context with semantic embeddings to find relevant information.
|
||||
|
||||
.. note::
|
||||
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
|
||||
To install:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"
|
||||
|
||||
Before using this tool, you must complete the GraphRAG setup and indexing process:
|
||||
|
||||
1. Follow the GraphRAG documentation to initialize your project and settings
|
||||
2. Configure and tune your prompts for the specific use case
|
||||
3. Run the indexing process to generate the required data files
|
||||
4. Ensure you have the settings.yaml file from the setup process
|
||||
|
||||
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
|
||||
for detailed instructions on completing these prerequisite steps.
|
||||
|
||||
Example usage with AssistantAgent:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.tools.graphrag import LocalSearchTool
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize the OpenAI client
|
||||
openai_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o-mini",
|
||||
api_key="<api-key>",
|
||||
)
|
||||
|
||||
# Set up local search tool
|
||||
local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
|
||||
|
||||
# Create assistant agent with the local search tool
|
||||
assistant_agent = AssistantAgent(
|
||||
name="search_assistant",
|
||||
tools=[local_tool],
|
||||
model_client=openai_client,
|
||||
system_message=(
|
||||
"You are a tool selector AI assistant using the GraphRAG framework. "
|
||||
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
|
||||
"For specific, detailed information about particular entities or relationships, call the 'local_search' function."
|
||||
),
|
||||
)
|
||||
|
||||
# Run a sample query
|
||||
query = "What does the station-master say about Dr. Becher?"
|
||||
await Console(assistant_agent.run_stream(task=query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Args:
|
||||
token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
|
||||
model: The chat model to use for search (GraphRAG ChatModel)
|
||||
embedder: The text embedding model to use (GraphRAG EmbeddingModel)
|
||||
data_config (DataConfig): Configuration for data source locations and settings
|
||||
context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
|
||||
search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
model: ChatModel, # ChatModel from GraphRAG
|
||||
embedder: EmbeddingModel, # EmbeddingModel from GraphRAG
|
||||
data_config: DataConfig,
|
||||
context_config: LocalContextConfig = _default_context_config,
|
||||
search_config: SearchConfig = _default_search_config,
|
||||
):
|
||||
super().__init__(
|
||||
args_type=LocalSearchToolArgs,
|
||||
return_type=LocalSearchToolReturn,
|
||||
name="local_search_tool",
|
||||
description="Perform a local search with given parameters using graphrag.",
|
||||
)
|
||||
# Use the provided models
|
||||
self._model = model
|
||||
self._embedder = embedder
|
||||
|
||||
# Load parquet files
|
||||
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
|
||||
relationship_df: pd.DataFrame = pd.read_parquet( # type: ignore
|
||||
f"{data_config.input_dir}/{data_config.relationship_table}.parquet"
|
||||
)
|
||||
text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet") # type: ignore
|
||||
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
|
||||
|
||||
# Read data using indexer adapters
|
||||
entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
|
||||
relationships = read_indexer_relationships(relationship_df)
|
||||
text_units = read_indexer_text_units(text_unit_df)
|
||||
# Set up vector store for entity embeddings
|
||||
description_embedding_store = LanceDBVectorStore(
|
||||
collection_name="default-entity-description",
|
||||
)
|
||||
description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb")
|
||||
|
||||
# Set up context builder
|
||||
context_builder = LocalSearchMixedContext(
|
||||
entities=entities,
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
text_embedder=self._embedder,
|
||||
text_units=text_units,
|
||||
relationships=relationships,
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
|
||||
context_builder_params = {
|
||||
"text_unit_prop": context_config.text_unit_prop,
|
||||
"community_prop": context_config.community_prop,
|
||||
"include_entity_rank": context_config.include_entity_rank,
|
||||
"rank_description": context_config.rank_description,
|
||||
"include_relationship_weight": context_config.include_relationship_weight,
|
||||
"relationship_ranking_attribute": context_config.relationship_ranking_attribute,
|
||||
"max_tokens": context_config.max_data_tokens,
|
||||
}
|
||||
|
||||
llm_params = {
|
||||
"max_tokens": search_config.max_tokens,
|
||||
"temperature": search_config.temperature,
|
||||
}
|
||||
|
||||
self._search_engine = LocalSearch(
|
||||
model=self._model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
response_type=search_config.response_type,
|
||||
context_builder_params=context_builder_params,
|
||||
model_params=llm_params,
|
||||
)
|
||||
|
||||
async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
|
||||
search_result = await self._search_engine.search(args.query) # type: ignore[reportUnknownMemberType]
|
||||
assert isinstance(search_result.response, str), "Expected response to be a string"
|
||||
return LocalSearchToolReturn(answer=search_result.response)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls, root_dir: Path, config_filepath: Path | None = None) -> "LocalSearchTool":
|
||||
"""Create a LocalSearchTool instance from GraphRAG settings file.
|
||||
|
||||
Args:
|
||||
root_dir: Path to the GraphRAG root directory
|
||||
config_filepath: Path to the GraphRAG settings file (optional)
|
||||
|
||||
Returns:
|
||||
An initialized LocalSearchTool instance
|
||||
"""
|
||||
# Load GraphRAG config
|
||||
config = load_config(root_dir=root_dir, config_filepath=config_filepath)
|
||||
|
||||
# Get the language model configurations from the models section
|
||||
chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
|
||||
embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)
|
||||
|
||||
if chat_model_config is None:
|
||||
raise ValueError("default_chat_model not found in config.models")
|
||||
if embedding_model_config is None:
|
||||
raise ValueError("default_embedding_model not found in config.models")
|
||||
|
||||
# Initialize token encoder based on the model being used
|
||||
try:
|
||||
token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
|
||||
except KeyError:
|
||||
# Fallback to cl100k_base if model is not recognized by tiktoken
|
||||
token_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# Create the models using ModelManager
|
||||
model = ModelManager().get_or_create_chat_model(
|
||||
name="local_search_model",
|
||||
model_type=chat_model_config.type,
|
||||
config=chat_model_config,
|
||||
)
|
||||
|
||||
embedder = ModelManager().get_or_create_embedding_model(
|
||||
name="local_search_embedder",
|
||||
model_type=embedding_model_config.type,
|
||||
config=embedding_model_config,
|
||||
)
|
||||
|
||||
# Create data config from storage paths
|
||||
data_config = DataConfig(
|
||||
input_dir=str(config.output.base_dir),
|
||||
)
|
||||
|
||||
return cls(
|
||||
token_encoder=token_encoder,
|
||||
model=model,
|
||||
embedder=embedder,
|
||||
data_config=data_config,
|
||||
context_config=_default_context_config,
|
||||
search_config=_default_search_config,
|
||||
)
|
||||
3
agent_dhal/agentdhal_extensions/tools/http/__init__.py
Normal file
3
agent_dhal/agentdhal_extensions/tools/http/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._http_tool import HttpTool
|
||||
|
||||
__all__ = ["HttpTool"]
|
||||
244
agent_dhal/agentdhal_extensions/tools/http/_http_tool.py
Normal file
244
agent_dhal/agentdhal_extensions/tools/http/_http_tool.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import httpx
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from json_schema_to_pydantic import create_model
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
DEFAULT_TIMEOUT_CONFIG = 5.0
|
||||
|
||||
|
||||
class HttpToolConfig(BaseModel):
|
||||
name: str
|
||||
"""
|
||||
The name of the tool.
|
||||
"""
|
||||
description: Optional[str]
|
||||
"""
|
||||
A description of the tool.
|
||||
"""
|
||||
scheme: Literal["http", "https"] = "http"
|
||||
"""
|
||||
The scheme to use for the request.
|
||||
"""
|
||||
host: str
|
||||
"""
|
||||
The URL to send the request to.
|
||||
"""
|
||||
port: int
|
||||
"""
|
||||
The port to send the request to.
|
||||
"""
|
||||
path: str = Field(default="/")
|
||||
"""
|
||||
The path to send the request to. defaults to "/"
|
||||
The path can accept parameters, e.g. "/{param1}/{param2}".
|
||||
These parameters will be templated from the inputs args, any additional parameters will be added as query parameters or the body of the request.
|
||||
"""
|
||||
method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
|
||||
"""
|
||||
The HTTP method to use, will default to POST if not provided.
|
||||
"""
|
||||
headers: Optional[dict[str, Any]]
|
||||
"""
|
||||
A dictionary of headers to send with the request.
|
||||
"""
|
||||
json_schema: dict[str, Any]
|
||||
"""
|
||||
A JSON Schema object defining the expected parameters for the tool.
|
||||
Path parameters MUST also be included in the json_schema. They must also MUST be set to string
|
||||
"""
|
||||
return_type: Optional[Literal["text", "json"]] = "text"
|
||||
"""
|
||||
The type of response to return from the tool.
|
||||
"""
|
||||
timeout: float = DEFAULT_TIMEOUT_CONFIG
|
||||
"""
|
||||
The timeout for the tool request in seconds.
|
||||
"""
|
||||
|
||||
|
||||
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
|
||||
"""A wrapper for using an HTTP server as a tool.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool.
|
||||
description (str, optional): A description of the tool.
|
||||
scheme (str): The scheme to use for the request. Must be either "http" or "https".
|
||||
host (str): The host to send the request to.
|
||||
port (int): The port to send the request to.
|
||||
path (str, optional): The path to send the request to. Defaults to "/".
|
||||
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
|
||||
method (str, optional): The HTTP method to use, will default to POST if not provided.
|
||||
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
|
||||
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
|
||||
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
|
||||
Path parameters must also be included in the schema and must be strings.
|
||||
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
|
||||
Defaults to "text".
|
||||
timeout (float, optional): The timeout for HTTP requests in seconds.
|
||||
Defaults to 5.0.
|
||||
|
||||
.. note::
|
||||
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
|
||||
|
||||
To install:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-agentchat" "agentdhal-ext[http-tool]"
|
||||
|
||||
Example:
|
||||
Simple use case::
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.http import HttpTool
|
||||
|
||||
# Define a JSON schema for a base64 decode tool
|
||||
base64_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"type": "string", "description": "The base64 value to decode"},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
|
||||
# Create an HTTP tool for the httpbin API
|
||||
base64_tool = HttpTool(
|
||||
name="base64_decode",
|
||||
description="base64 decode a value",
|
||||
scheme="https",
|
||||
host="httpbin.org",
|
||||
port=443,
|
||||
path="/base64/{value}",
|
||||
method="GET",
|
||||
json_schema=base64_schema,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
# Create an assistant with the base64 tool
|
||||
model = OpenAIChatCompletionClient(model="gpt-4")
|
||||
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
|
||||
|
||||
# The assistant can now use the base64 tool to decode the string
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
|
||||
CancellationToken(),
|
||||
)
|
||||
print(response.chat_message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
component_type = "tool"
|
||||
component_provider_override = "agentdhal_extensions.tools.http.HttpTool"
|
||||
component_config_schema = HttpToolConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
json_schema: dict[str, Any],
|
||||
headers: Optional[dict[str, Any]] = None,
|
||||
description: str = "HTTP tool",
|
||||
path: str = "/",
|
||||
scheme: Literal["http", "https"] = "http",
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
|
||||
return_type: Literal["text", "json"] = "text",
|
||||
timeout: float = DEFAULT_TIMEOUT_CONFIG,
|
||||
) -> None:
|
||||
self.server_params = HttpToolConfig(
|
||||
name=name,
|
||||
description=description,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
scheme=scheme,
|
||||
method=method,
|
||||
headers=headers,
|
||||
json_schema=json_schema,
|
||||
return_type=return_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Use regex to find all path parameters, we will need those later to template the path
|
||||
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
|
||||
self._path_params = path_params
|
||||
|
||||
# Create the input model from the modified schema
|
||||
input_model = create_model(json_schema)
|
||||
|
||||
# Use Any as return type since HTTP responses can vary
|
||||
base_return_type: Type[Any] = object
|
||||
|
||||
super().__init__(input_model, base_return_type, name, description)
|
||||
|
||||
def _to_config(self) -> HttpToolConfig:
|
||||
copied_config = self.server_params.model_copy()
|
||||
return copied_config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: HttpToolConfig) -> Self:
|
||||
copied_config = config.model_copy().model_dump()
|
||||
return cls(**copied_config)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
"""Execute the HTTP tool with the given arguments.
|
||||
|
||||
Args:
|
||||
args: The validated input arguments
|
||||
cancellation_token: Token for cancelling the operation
|
||||
|
||||
Returns:
|
||||
The response body from the HTTP call in JSON format
|
||||
|
||||
Raises:
|
||||
Exception: If tool execution fails
|
||||
"""
|
||||
|
||||
model_dump = args.model_dump()
|
||||
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
|
||||
# Remove path params from the model dump
|
||||
for k in self._path_params:
|
||||
model_dump.pop(k)
|
||||
|
||||
path = self.server_params.path.format(**path_params)
|
||||
|
||||
url = httpx.URL(
|
||||
scheme=self.server_params.scheme,
|
||||
host=self.server_params.host,
|
||||
port=self.server_params.port,
|
||||
path=path,
|
||||
)
|
||||
timeout_config = httpx.Timeout(timeout=self.server_params.timeout)
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
match self.server_params.method:
|
||||
case "GET":
|
||||
response = await client.get(url, headers=self.server_params.headers, params=model_dump)
|
||||
case "PUT":
|
||||
response = await client.put(url, headers=self.server_params.headers, json=model_dump)
|
||||
case "DELETE":
|
||||
response = await client.delete(url, headers=self.server_params.headers, params=model_dump)
|
||||
case "PATCH":
|
||||
response = await client.patch(url, headers=self.server_params.headers, json=model_dump)
|
||||
case _: # Default case POST
|
||||
response = await client.post(url, headers=self.server_params.headers, json=model_dump)
|
||||
|
||||
match self.server_params.return_type:
|
||||
case "text":
|
||||
return response.text
|
||||
case "json":
|
||||
return response.json()
|
||||
case _:
|
||||
raise ValueError(f"Invalid return type: {self.server_params.return_type}")
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._langchain_adapter import LangChainToolAdapter
|
||||
|
||||
__all__ = ["LangChainToolAdapter"]
|
||||
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, cast
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool as LangChainTool
|
||||
|
||||
|
||||
class LangChainToolAdapter(BaseTool[BaseModel, Any]):
|
||||
"""Allows you to wrap a LangChain tool and make it available to AutoGen.
|
||||
|
||||
.. note::
|
||||
|
||||
This class requires the :code:`langchain` extra for the :code:`autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[langchain]"
|
||||
|
||||
|
||||
Args:
|
||||
langchain_tool (LangChainTool): A LangChain tool to wrap
|
||||
|
||||
Examples:
|
||||
|
||||
Use the `PythonAstREPLTool` from the `langchain_experimental` package to
|
||||
create a tool that allows you to interact with a Pandas DataFrame.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
||||
from agentdhal_extensions.tools.langchain import LangChainToolAdapter
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_agentchat.messages import TextMessage
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv") # type: ignore
|
||||
tool = LangChainToolAdapter(PythonAstREPLTool(locals={"df": df}))
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(
|
||||
"assistant",
|
||||
tools=[tool],
|
||||
model_client=model_client,
|
||||
system_message="Use the `df` variable to access the dataset.",
|
||||
)
|
||||
await Console(
|
||||
agent.on_messages_stream(
|
||||
[TextMessage(content="What's the average age of the passengers?", source="user")], CancellationToken()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
This example demonstrates how to use the `SQLDatabaseToolkit` from the `langchain_community`
|
||||
package to interact with an SQLite database.
|
||||
It uses the :class:`~agentdhal_agentchat.team.RoundRobinGroupChat` to iterate the single agent over multiple steps.
|
||||
If you want to one step at a time, you can just call `run_stream` method of the
|
||||
:class:`~agentdhal_agentchat.agents.AssistantAgent` class directly.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
import requests
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.langchain import LangChainToolAdapter
|
||||
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
|
||||
def get_engine_for_chinook_db() -> Engine:
|
||||
url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
|
||||
response = requests.get(url)
|
||||
sql_script = response.text
|
||||
connection = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
connection.executescript(sql_script)
|
||||
return create_engine(
|
||||
"sqlite://",
|
||||
creator=lambda: connection,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create the engine and database wrapper.
|
||||
engine = get_engine_for_chinook_db()
|
||||
db = SQLDatabase(engine)
|
||||
|
||||
# Create the toolkit.
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
||||
|
||||
# Create the LangChain tool adapter for every tool in the toolkit.
|
||||
tools = [LangChainToolAdapter(tool) for tool in toolkit.get_tools()]
|
||||
|
||||
# Create the chat completion client.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
# Create the assistant agent.
|
||||
agent = AssistantAgent(
|
||||
"assistant",
|
||||
model_client=model_client,
|
||||
tools=tools, # type: ignore
|
||||
model_client_stream=True,
|
||||
system_message="Respond with 'TERMINATE' if the task is completed.",
|
||||
)
|
||||
|
||||
# Create termination condition.
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
|
||||
# Create a round-robin group chat to iterate the single agent over multiple steps.
|
||||
chat = RoundRobinGroupChat([agent], termination_condition=termination)
|
||||
|
||||
# Run the chat.
|
||||
await Console(chat.run_stream(task="Show some tables in the database"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, langchain_tool: LangChainTool):
|
||||
self._langchain_tool: LangChainTool = langchain_tool
|
||||
|
||||
# Extract name and description
|
||||
name = self._langchain_tool.name
|
||||
description = self._langchain_tool.description or ""
|
||||
|
||||
# Determine the callable method
|
||||
if hasattr(self._langchain_tool, "func") and callable(self._langchain_tool.func): # type: ignore
|
||||
assert self._langchain_tool.func is not None # type: ignore
|
||||
self._callable: Callable[..., Any] = self._langchain_tool.func # type: ignore
|
||||
elif hasattr(self._langchain_tool, "_run") and callable(self._langchain_tool._run): # type: ignore
|
||||
self._callable: Callable[..., Any] = self._langchain_tool._run # type: ignore
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method."
|
||||
)
|
||||
|
||||
# Determine args_type
|
||||
if self._langchain_tool.args_schema: # pyright: ignore
|
||||
args_type = self._langchain_tool.args_schema # pyright: ignore
|
||||
else:
|
||||
# Infer args_type from the callable's signature
|
||||
sig = inspect.signature(cast(Callable[..., Any], self._callable)) # type: ignore
|
||||
fields = {
|
||||
k: (v.annotation, Field(...))
|
||||
for k, v in sig.parameters.items()
|
||||
if k != "self" and v.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
|
||||
}
|
||||
args_type = create_model(f"{name}Args", **fields) # type: ignore
|
||||
# Note: type ignore is used due to a LangChain typing limitation
|
||||
|
||||
# Ensure args_type is a subclass of BaseModel
|
||||
if not issubclass(args_type, BaseModel):
|
||||
raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}")
|
||||
|
||||
# Assume return_type as Any if not specified
|
||||
return_type: Type[Any] = object
|
||||
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
# Prepare arguments
|
||||
kwargs = args.model_dump()
|
||||
|
||||
# Determine if the callable is asynchronous
|
||||
if inspect.iscoroutinefunction(self._callable):
|
||||
return await self._callable(**kwargs)
|
||||
else:
|
||||
# Run in a thread to avoid blocking the event loop
|
||||
return await asyncio.to_thread(self._call_sync, kwargs)
|
||||
|
||||
def _call_sync(self, kwargs: Dict[str, Any]) -> Any:
|
||||
return self._callable(**kwargs)
|
||||
22
agent_dhal/agentdhal_extensions/tools/mcp/__init__.py
Normal file
22
agent_dhal/agentdhal_extensions/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from ._actor import McpSessionActor
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
|
||||
from ._factory import mcp_server_tools
|
||||
from ._session import create_mcp_server_session
|
||||
from ._sse import SseMcpToolAdapter
|
||||
from ._stdio import StdioMcpToolAdapter
|
||||
from ._streamable_http import StreamableHttpMcpToolAdapter
|
||||
from ._workbench import McpWorkbench
|
||||
|
||||
__all__ = [
|
||||
"create_mcp_server_session",
|
||||
"McpSessionActor",
|
||||
"StdioMcpToolAdapter",
|
||||
"StdioServerParams",
|
||||
"SseMcpToolAdapter",
|
||||
"SseServerParams",
|
||||
"StreamableHttpMcpToolAdapter",
|
||||
"StreamableHttpServerParams",
|
||||
"McpServerParams",
|
||||
"mcp_server_tools",
|
||||
"McpWorkbench",
|
||||
]
|
||||
310
agent_dhal/agentdhal_extensions/tools/mcp/_actor.py
Normal file
310
agent_dhal/agentdhal_extensions/tools/mcp/_actor.py
Normal file
@@ -0,0 +1,310 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, Coroutine, Dict, Mapping, TypedDict
|
||||
|
||||
from agentdhal_core import Component, ComponentBase, ComponentModel, Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
ModelInfo,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.shared.context import RequestContext
|
||||
|
||||
from ._config import McpServerParams
|
||||
from ._session import create_mcp_server_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
McpResult = (
|
||||
Coroutine[Any, Any, mcp_types.ListToolsResult]
|
||||
| Coroutine[Any, Any, mcp_types.CallToolResult]
|
||||
| Coroutine[Any, Any, mcp_types.ListPromptsResult]
|
||||
| Coroutine[Any, Any, mcp_types.ListResourcesResult]
|
||||
| Coroutine[Any, Any, mcp_types.ListResourceTemplatesResult]
|
||||
| Coroutine[Any, Any, mcp_types.ReadResourceResult]
|
||||
| Coroutine[Any, Any, mcp_types.GetPromptResult]
|
||||
)
|
||||
McpFuture = asyncio.Future[McpResult]
|
||||
|
||||
|
||||
def _parse_sampling_content(
|
||||
content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent, model_info: ModelInfo
|
||||
) -> str | Image:
|
||||
"""Convert MCP content types to Autogen content types."""
|
||||
if content.type == "text":
|
||||
return content.text
|
||||
elif content.type == "image":
|
||||
if not model_info["vision"]:
|
||||
raise ValueError("Sampling model does not support image content.")
|
||||
# Decode base64 image data and create PIL Image
|
||||
image_data = base64.b64decode(content.data)
|
||||
pil_image = PILImage.open(io.BytesIO(image_data))
|
||||
return Image.from_pil(pil_image)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content.type}")
|
||||
|
||||
|
||||
def _parse_sampling_message(message: mcp_types.SamplingMessage, model_info: ModelInfo) -> LLMMessage:
|
||||
"""Convert MCP sampling messages to Autogen messages."""
|
||||
content = _parse_sampling_content(message.content, model_info=model_info)
|
||||
if message.role == "user":
|
||||
return UserMessage(
|
||||
source="user",
|
||||
content=[content],
|
||||
)
|
||||
elif message.role == "assistant":
|
||||
assert isinstance(content, str), "Assistant messages only support string content."
|
||||
return AssistantMessage(
|
||||
source="assistant",
|
||||
content=content,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized message role: {message.role}")
|
||||
|
||||
|
||||
class McpActorArgs(TypedDict):
|
||||
name: str | None
|
||||
kargs: Mapping[str, Any]
|
||||
|
||||
|
||||
class McpSessionActorConfig(BaseModel):
|
||||
server_params: McpServerParams
|
||||
model_client: ComponentModel | Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
|
||||
component_type = "mcp_session_actor"
|
||||
component_config_schema = McpSessionActorConfig
|
||||
component_provider_override = "agentdhal_extensions.tools.mcp.McpSessionActor"
|
||||
|
||||
server_params: McpServerParams
|
||||
|
||||
# model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, server_params: McpServerParams, model_client: ChatCompletionClient | None = None) -> None:
|
||||
self.server_params: McpServerParams = server_params
|
||||
self._model_client = model_client
|
||||
self.name = "mcp_session_actor"
|
||||
self.description = "MCP session actor"
|
||||
self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
|
||||
self._actor_task: asyncio.Task[Any] | None = None
|
||||
self._shutdown_future: asyncio.Future[Any] | None = None
|
||||
self._active = False
|
||||
self._initialize_result: mcp_types.InitializeResult | None = None
|
||||
atexit.register(self._sync_shutdown)
|
||||
|
||||
@property
|
||||
def initialize_result(self) -> mcp_types.InitializeResult | None:
|
||||
return self._initialize_result
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self._active:
|
||||
self._active = True
|
||||
self._actor_task = asyncio.create_task(self._run_actor())
|
||||
|
||||
async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture:
|
||||
if not self._active:
|
||||
raise RuntimeError("MCP Actor not running, call initialize() first")
|
||||
if self._actor_task and self._actor_task.done():
|
||||
raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
|
||||
fut: asyncio.Future[McpFuture] = asyncio.Future()
|
||||
if type in {"list_tools", "list_prompts", "list_resources", "list_resource_templates", "shutdown"}:
|
||||
await self._command_queue.put({"type": type, "future": fut})
|
||||
res = await fut
|
||||
elif type in {"call_tool", "read_resource", "get_prompt"}:
|
||||
if args is None:
|
||||
raise ValueError(f"args is required for {type}")
|
||||
name = args.get("name", None)
|
||||
kwargs = args.get("kargs", {})
|
||||
if type == "call_tool" and name is None:
|
||||
raise ValueError("name is required for call_tool")
|
||||
elif type == "read_resource":
|
||||
uri = kwargs.get("uri", None)
|
||||
if uri is None:
|
||||
raise ValueError("uri is required for read_resource")
|
||||
await self._command_queue.put({"type": type, "uri": uri, "future": fut})
|
||||
elif type == "get_prompt":
|
||||
if name is None:
|
||||
raise ValueError("name is required for get_prompt")
|
||||
prompt_args = kwargs.get("arguments", None)
|
||||
await self._command_queue.put({"type": type, "name": name, "args": prompt_args, "future": fut})
|
||||
else: # call_tool
|
||||
await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut})
|
||||
res = await fut
|
||||
else:
|
||||
raise ValueError(f"Unknown command type: {type}")
|
||||
return res
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._active or self._actor_task is None:
|
||||
return
|
||||
self._shutdown_future = asyncio.Future()
|
||||
await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
|
||||
await self._shutdown_future
|
||||
await self._actor_task
|
||||
self._active = False
|
||||
|
||||
async def _sampling_callback(
|
||||
self,
|
||||
context: RequestContext[ClientSession, Any],
|
||||
params: mcp_types.CreateMessageRequestParams,
|
||||
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
|
||||
"""Handle sampling requests using the provided model client."""
|
||||
if self._model_client is None:
|
||||
# Return an error when no model client is available
|
||||
return mcp_types.ErrorData(
|
||||
code=mcp_types.INVALID_REQUEST,
|
||||
message="No model client available for sampling.",
|
||||
data=None,
|
||||
)
|
||||
|
||||
llm_messages: list[LLMMessage] = []
|
||||
|
||||
try:
|
||||
if params.systemPrompt:
|
||||
llm_messages.append(SystemMessage(content=params.systemPrompt))
|
||||
|
||||
for mcp_message in params.messages:
|
||||
llm_messages.append(_parse_sampling_message(mcp_message, model_info=self._model_client.model_info))
|
||||
|
||||
except Exception as e:
|
||||
return mcp_types.ErrorData(
|
||||
code=mcp_types.INVALID_PARAMS,
|
||||
message="Error processing sampling messages.",
|
||||
data=f"{type(e).__name__}: {e}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._model_client.create(messages=llm_messages)
|
||||
|
||||
content = result.content
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
return mcp_types.CreateMessageResult(
|
||||
role="assistant",
|
||||
content=mcp_types.TextContent(type="text", text=content),
|
||||
model=self._model_client.model_info["family"],
|
||||
stopReason=result.finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return mcp_types.ErrorData(
|
||||
code=mcp_types.INTERNAL_ERROR,
|
||||
message="Error sampling from model client.",
|
||||
data=f"{type(e).__name__}: {e}",
|
||||
)
|
||||
|
||||
async def _run_actor(self) -> None:
|
||||
result: McpResult
|
||||
try:
|
||||
async with create_mcp_server_session(
|
||||
self.server_params, sampling_callback=self._sampling_callback
|
||||
) as session:
|
||||
# Save the initialize result
|
||||
self._initialize_result = await session.initialize()
|
||||
while True:
|
||||
cmd = await self._command_queue.get()
|
||||
if cmd["type"] == "shutdown":
|
||||
cmd["future"].set_result("ok")
|
||||
break
|
||||
elif cmd["type"] == "call_tool":
|
||||
try:
|
||||
result = session.call_tool(name=cmd["name"], arguments=cmd["args"])
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "read_resource":
|
||||
try:
|
||||
result = session.read_resource(uri=cmd["uri"])
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "get_prompt":
|
||||
try:
|
||||
result = session.get_prompt(name=cmd["name"], arguments=cmd["args"])
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "list_tools":
|
||||
try:
|
||||
result = session.list_tools()
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "list_prompts":
|
||||
try:
|
||||
result = session.list_prompts()
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "list_resources":
|
||||
try:
|
||||
result = session.list_resources()
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "list_resource_templates":
|
||||
try:
|
||||
result = session.list_resource_templates()
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
except Exception as e:
|
||||
if self._shutdown_future and not self._shutdown_future.done():
|
||||
self._shutdown_future.set_exception(e)
|
||||
else:
|
||||
logger.exception("Exception in MCP actor task")
|
||||
finally:
|
||||
self._active = False
|
||||
self._actor_task = None
|
||||
|
||||
def _sync_shutdown(self) -> None:
|
||||
if not self._active or self._actor_task is None:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# No loop available — interpreter is likely shutting down
|
||||
return
|
||||
|
||||
if loop.is_closed():
|
||||
return
|
||||
|
||||
if loop.is_running():
|
||||
loop.create_task(self.close())
|
||||
else:
|
||||
loop.run_until_complete(self.close())
|
||||
|
||||
def _to_config(self) -> McpSessionActorConfig:
|
||||
"""
|
||||
Convert the adapter to its configuration representation.
|
||||
|
||||
Returns:
|
||||
McpSessionConfig: The configuration of the adapter.
|
||||
"""
|
||||
return McpSessionActorConfig(server_params=self.server_params)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: McpSessionActorConfig) -> Self:
|
||||
"""
|
||||
Create an instance of McpSessionActor from its configuration.
|
||||
|
||||
Args:
|
||||
config (McpSessionConfig): The configuration of the adapter.
|
||||
|
||||
Returns:
|
||||
McpSessionActor: An instance of SseMcpToolAdapter.
|
||||
"""
|
||||
return cls(server_params=config.server_params)
|
||||
190
agent_dhal/agentdhal_extensions/tools/mcp/_base.py
Normal file
190
agent_dhal/agentdhal_extensions/tools/mcp/_base.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import json
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from agentdhal_core.utils import schema_to_pydantic_model
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
from mcp.types import AudioContent, ContentBlock, EmbeddedResource, ImageContent, ResourceLink, TextContent
|
||||
|
||||
from ._config import McpServerParams
|
||||
from ._session import create_mcp_server_session
|
||||
|
||||
TServerParams = TypeVar("TServerParams", bound=McpServerParams)
|
||||
|
||||
|
||||
class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
|
||||
"""
|
||||
Base adapter class for MCP tools to make them compatible with AutoGen.
|
||||
|
||||
Args:
|
||||
server_params (TServerParams): Parameters for the MCP server connection.
|
||||
tool (Tool): The MCP tool to wrap.
|
||||
"""
|
||||
|
||||
component_type = "tool"
|
||||
|
||||
def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
|
||||
self._tool = tool
|
||||
self._server_params = server_params
|
||||
self._session = session
|
||||
|
||||
# Extract name and description
|
||||
name = tool.name
|
||||
description = tool.description or ""
|
||||
|
||||
# Create the input model from the tool's schema
|
||||
input_model = schema_to_pydantic_model(tool.inputSchema)
|
||||
|
||||
# Use Any as return type since MCP tool returns can vary
|
||||
return_type: Type[Any] = object
|
||||
|
||||
super().__init__(input_model, return_type, name, description)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
"""
|
||||
Run the MCP tool with the provided arguments.
|
||||
|
||||
Args:
|
||||
args (BaseModel): The arguments to pass to the tool.
|
||||
cancellation_token (CancellationToken): Token to signal cancellation.
|
||||
|
||||
Returns:
|
||||
Any: The result of the tool execution.
|
||||
|
||||
Raises:
|
||||
Exception: If the operation is cancelled or the tool execution fails.
|
||||
"""
|
||||
# Convert the input model to a dictionary
|
||||
# Exclude unset values to avoid sending them to the MCP servers which may cause errors
|
||||
# for many servers.
|
||||
kwargs = args.model_dump(exclude_unset=True)
|
||||
|
||||
if self._session is not None:
|
||||
# If a session is provided, use it directly.
|
||||
session = self._session
|
||||
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
|
||||
|
||||
async with create_mcp_server_session(self._server_params) as session:
|
||||
await session.initialize()
|
||||
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
|
||||
|
||||
def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]:
|
||||
"""
|
||||
Normalizes a raw tool output payload into a list of content items.
|
||||
- If payload is already a sequence of ContentBlock items, it's converted to a list and returned.
|
||||
- If payload is a single ContentBlock item, it's wrapped in a list.
|
||||
- If payload is a string, it's wrapped in [TextContent(text=payload)].
|
||||
- Otherwise, the payload is stringified and wrapped in [TextContent(text=str(payload))].
|
||||
"""
|
||||
if isinstance(payload, Sequence) and all(
|
||||
isinstance(item, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink))
|
||||
for item in payload
|
||||
):
|
||||
return list(payload)
|
||||
elif isinstance(payload, (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)):
|
||||
return [payload]
|
||||
elif isinstance(payload, str):
|
||||
return [TextContent(text=payload, type="text")]
|
||||
else:
|
||||
return [TextContent(text=str(payload), type="text")]
|
||||
|
||||
async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
|
||||
exceptions_to_catch: tuple[Type[BaseException], ...]
|
||||
if hasattr(builtins, "ExceptionGroup"):
|
||||
exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup)
|
||||
else:
|
||||
exceptions_to_catch = (asyncio.CancelledError,)
|
||||
|
||||
try:
|
||||
if cancellation_token.is_cancelled():
|
||||
raise asyncio.CancelledError("Operation cancelled")
|
||||
|
||||
result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
|
||||
normalized_content_list = self._normalize_payload_to_content_list(result.content)
|
||||
|
||||
if result.isError:
|
||||
serialized_error_message = self.return_value_as_string(normalized_content_list)
|
||||
raise Exception(serialized_error_message)
|
||||
return normalized_content_list
|
||||
|
||||
except exceptions_to_catch:
|
||||
# Re-raise these specific exception types directly.
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
|
||||
"""
|
||||
Create an instance of McpToolAdapter from server parameters and tool name.
|
||||
|
||||
Args:
|
||||
server_params (TServerParams): Parameters for the MCP server connection.
|
||||
tool_name (str): The name of the tool to wrap.
|
||||
|
||||
Returns:
|
||||
McpToolAdapter[TServerParams]: An instance of McpToolAdapter.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tool with the specified name is not found.
|
||||
"""
|
||||
async with create_mcp_server_session(server_params) as session:
|
||||
await session.initialize()
|
||||
|
||||
tools_response = await session.list_tools()
|
||||
matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None)
|
||||
|
||||
if matching_tool is None:
|
||||
raise ValueError(
|
||||
f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}"
|
||||
)
|
||||
|
||||
return cls(server_params=server_params, tool=matching_tool)
|
||||
|
||||
def return_value_as_string(self, value: list[Any]) -> str:
|
||||
"""Return a string representation of the result."""
|
||||
|
||||
def serialize_item(item: Any) -> dict[str, Any]:
|
||||
if isinstance(item, (TextContent, ImageContent, AudioContent)):
|
||||
dumped = item.model_dump()
|
||||
# Remove the 'meta' field if it exists and is None (for backward compatibility)
|
||||
if dumped.get("meta") is None:
|
||||
dumped.pop("meta", None)
|
||||
return dumped
|
||||
elif isinstance(item, EmbeddedResource):
|
||||
type = item.type
|
||||
resource = {}
|
||||
for key, val in item.resource.model_dump().items():
|
||||
# Skip 'meta' field if it's None (for backward compatibility)
|
||||
if key == "meta" and val is None:
|
||||
continue
|
||||
if isinstance(val, AnyUrl):
|
||||
resource[key] = str(val)
|
||||
else:
|
||||
resource[key] = val
|
||||
dumped_annotations = item.annotations.model_dump() if item.annotations else None
|
||||
# Remove 'meta' from annotations if it exists and is None
|
||||
if dumped_annotations and dumped_annotations.get("meta") is None:
|
||||
dumped_annotations.pop("meta", None)
|
||||
return {"type": type, "resource": resource, "annotations": dumped_annotations}
|
||||
elif isinstance(item, ResourceLink):
|
||||
dumped = item.model_dump()
|
||||
# Remove the 'meta' field if it exists and is None (for backward compatibility)
|
||||
if dumped.get("meta") is None:
|
||||
dumped.pop("meta", None)
|
||||
# Convert AnyUrl to string for JSON serialization
|
||||
if "uri" in dumped and isinstance(dumped["uri"], AnyUrl):
|
||||
dumped["uri"] = str(dumped["uri"])
|
||||
return dumped
|
||||
else:
|
||||
return {}
|
||||
|
||||
return json.dumps([serialize_item(item) for item in value])
|
||||
42
agent_dhal/agentdhal_extensions/tools/mcp/_config.py
Normal file
42
agent_dhal/agentdhal_extensions/tools/mcp/_config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
|
||||
class StdioServerParams(StdioServerParameters):
|
||||
"""Parameters for connecting to an MCP server over STDIO."""
|
||||
|
||||
type: Literal["StdioServerParams"] = "StdioServerParams"
|
||||
|
||||
read_timeout_seconds: float = 5
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
"""Parameters for connecting to an MCP server over SSE."""
|
||||
|
||||
type: Literal["SseServerParams"] = "SseServerParams"
|
||||
|
||||
url: str # The SSE endpoint URL.
|
||||
headers: dict[str, Any] | None = None # Optional headers to include in requests.
|
||||
timeout: float = 5 # HTTP timeout for regular operations.
|
||||
sse_read_timeout: float = 60 * 5 # Timeout for SSE read operations.
|
||||
|
||||
|
||||
class StreamableHttpServerParams(BaseModel):
|
||||
"""Parameters for connecting to an MCP server over Streamable HTTP."""
|
||||
|
||||
type: Literal["StreamableHttpServerParams"] = "StreamableHttpServerParams"
|
||||
|
||||
url: str # The endpoint URL.
|
||||
headers: dict[str, Any] | None = None # Optional headers to include in requests.
|
||||
timeout: float = 30.0 # HTTP timeout for regular operations in seconds.
|
||||
sse_read_timeout: float = 300.0 # Timeout for SSE read operations in seconds.
|
||||
terminate_on_close: bool = True
|
||||
|
||||
|
||||
McpServerParams = Annotated[
|
||||
StdioServerParams | SseServerParams | StreamableHttpServerParams, Field(discriminator="type")
|
||||
]
|
||||
214
agent_dhal/agentdhal_extensions/tools/mcp/_factory.py
Normal file
214
agent_dhal/agentdhal_extensions/tools/mcp/_factory.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from mcp import ClientSession
|
||||
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
|
||||
from ._session import create_mcp_server_session
|
||||
from ._sse import SseMcpToolAdapter
|
||||
from ._stdio import StdioMcpToolAdapter
|
||||
from ._streamable_http import StreamableHttpMcpToolAdapter
|
||||
|
||||
|
||||
async def mcp_server_tools(
|
||||
server_params: McpServerParams,
|
||||
session: ClientSession | None = None,
|
||||
) -> list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
|
||||
"""Creates a list of MCP tool adapters that can be used with AutoGen agents.
|
||||
|
||||
.. warning::
|
||||
|
||||
Only connect to trusted MCP servers, especially when using
|
||||
`StdioServerParams` as it executes commands in the local environment.
|
||||
|
||||
This factory function connects to an MCP server and returns adapters for all available tools.
|
||||
The adapters can be directly assigned to an AutoGen agent's tools list.
|
||||
|
||||
.. note::
|
||||
|
||||
To use this function, you need to install `mcp` extra for the `autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mcp]"
|
||||
|
||||
Args:
|
||||
server_params (McpServerParams): Connection parameters for the MCP server.
|
||||
Can be either StdioServerParams for command-line tools or
|
||||
SseServerParams and StreamableHttpServerParams for HTTP/SSE services.
|
||||
session (ClientSession | None): Optional existing session to use. This is used
|
||||
when you want to reuse an existing connection to the MCP server. The session
|
||||
will be reused when creating the MCP tool adapters.
|
||||
|
||||
Returns:
|
||||
list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
|
||||
A list of tool adapters ready to use with AutoGen agents.
|
||||
|
||||
Examples:
|
||||
|
||||
**Local file system MCP service over standard I/O example:**
|
||||
|
||||
Install the filesystem server package from npm (requires Node.js 16+ and npm).
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
npm install -g @modelcontextprotocol/server-filesystem
|
||||
|
||||
Create an agent that can use all tools from the local filesystem MCP server.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Setup server params for local filesystem access
|
||||
desktop = str(Path.home() / "Desktop")
|
||||
server_params = StdioServerParams(
|
||||
command="npx.cmd", args=["-y", "@modelcontextprotocol/server-filesystem", desktop]
|
||||
)
|
||||
|
||||
# Get all available tools from the server
|
||||
tools = await mcp_server_tools(server_params)
|
||||
|
||||
# Create an agent that can use all the tools
|
||||
agent = AssistantAgent(
|
||||
name="file_manager",
|
||||
model_client=OpenAIChatCompletionClient(model="gpt-4"),
|
||||
tools=tools, # type: ignore
|
||||
)
|
||||
|
||||
# The agent can now use any of the filesystem tools
|
||||
await agent.run(task="Create a file called test.txt with some content", cancellation_token=CancellationToken())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
**Local fetch MCP service over standard I/O example:**
|
||||
|
||||
Install the `mcp-server-fetch` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install mcp-server-fetch
|
||||
|
||||
Create an agent that can use the `fetch` tool from the local MCP server.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Get the fetch tool from mcp-server-fetch.
|
||||
fetch_mcp_server = StdioServerParams(command="uvx", args=["mcp-server-fetch"])
|
||||
tools = await mcp_server_tools(fetch_mcp_server)
|
||||
|
||||
# Create an agent that can use the fetch tool.
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="fetcher", model_client=model_client, tools=tools, reflect_on_tool_use=True) # type: ignore
|
||||
|
||||
# Let the agent fetch the content of a URL and summarize it.
|
||||
result = await agent.run(task="Summarize the content of https://en.wikipedia.org/wiki/Seattle")
|
||||
print(result.messages[-1])
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Sharing an MCP client session across multiple tools:**
|
||||
|
||||
You can create a single MCP client session and share it across multiple tools.
|
||||
This is sometimes required when the server maintains a session state
|
||||
(e.g., a browser state) that should be reused for multiple requests.
|
||||
|
||||
The following example show how to create a single MCP client session
|
||||
to a local `Playwright <https://github.com/microsoft/playwright-mcp>`_
|
||||
server and share it across multiple tools.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.conditions import TextMentionTermination
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False) # type: ignore
|
||||
params = StdioServerParams(
|
||||
command="npx",
|
||||
args=["@playwright/mcp@latest"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
async with create_mcp_server_session(params) as session:
|
||||
await session.initialize()
|
||||
tools = await mcp_server_tools(server_params=params, session=session)
|
||||
print(f"Tools: {[tool.name for tool in tools]}")
|
||||
|
||||
agent = AssistantAgent(
|
||||
name="Assistant",
|
||||
model_client=model_client,
|
||||
tools=tools, # type: ignore
|
||||
)
|
||||
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team = RoundRobinGroupChat([agent], termination_condition=termination)
|
||||
await Console(
|
||||
team.run_stream(
|
||||
task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
**Remote MCP service over SSE example:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agentdhal_extensions.tools.mcp import SseServerParams, mcp_server_tools
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Setup server params for remote service
|
||||
server_params = SseServerParams(url="https://api.example.com/mcp", headers={"Authorization": "Bearer token"})
|
||||
|
||||
# Get all available tools
|
||||
tools = await mcp_server_tools(server_params)
|
||||
|
||||
# Create an agent with all tools
|
||||
agent = AssistantAgent(name="tool_user", model_client=OpenAIChatCompletionClient(model="gpt-4"), tools=tools) # type: ignore
|
||||
|
||||
For more examples and detailed usage, see the samples directory in the package repository.
|
||||
"""
|
||||
if session is None:
|
||||
async with create_mcp_server_session(server_params) as temp_session:
|
||||
await temp_session.initialize()
|
||||
|
||||
tools = await temp_session.list_tools()
|
||||
else:
|
||||
tools = await session.list_tools()
|
||||
|
||||
if isinstance(server_params, StdioServerParams):
|
||||
return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
|
||||
elif isinstance(server_params, SseServerParams):
|
||||
return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
|
||||
elif isinstance(server_params, StreamableHttpServerParams):
|
||||
return [
|
||||
StreamableHttpMcpToolAdapter(server_params=server_params, tool=tool, session=session)
|
||||
for tool in tools.tools
|
||||
]
|
||||
raise ValueError(f"Unsupported server params type: {type(server_params)}")
|
||||
55
agent_dhal/agentdhal_extensions/tools/mcp/_session.py
Normal file
55
agent_dhal/agentdhal_extensions/tools/mcp/_session.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.session import SamplingFnT
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_server_session(
|
||||
server_params: McpServerParams, sampling_callback: SamplingFnT | None = None
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Create an MCP client session for the given server parameters."""
|
||||
if isinstance(server_params, StdioServerParams):
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(
|
||||
read_stream=read,
|
||||
write_stream=write,
|
||||
read_timeout_seconds=timedelta(seconds=server_params.read_timeout_seconds),
|
||||
sampling_callback=sampling_callback,
|
||||
) as session:
|
||||
yield session
|
||||
elif isinstance(server_params, SseServerParams):
|
||||
async with sse_client(**server_params.model_dump(exclude={"type"})) as (read, write):
|
||||
async with ClientSession(
|
||||
read_stream=read,
|
||||
write_stream=write,
|
||||
read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
|
||||
sampling_callback=sampling_callback,
|
||||
) as session:
|
||||
yield session
|
||||
elif isinstance(server_params, StreamableHttpServerParams):
|
||||
# Convert float seconds to timedelta for the streamablehttp_client
|
||||
params_dict = server_params.model_dump(exclude={"type"})
|
||||
params_dict["timeout"] = timedelta(seconds=server_params.timeout)
|
||||
params_dict["sse_read_timeout"] = timedelta(seconds=server_params.sse_read_timeout)
|
||||
|
||||
async with streamablehttp_client(**params_dict) as (
|
||||
read,
|
||||
write,
|
||||
session_id_callback, # type: ignore[assignment, unused-variable]
|
||||
):
|
||||
# TODO: Handle session_id_callback if needed
|
||||
async with ClientSession(
|
||||
read_stream=read,
|
||||
write_stream=write,
|
||||
read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
|
||||
sampling_callback=sampling_callback,
|
||||
) as session:
|
||||
yield session
|
||||
116
agent_dhal/agentdhal_extensions/tools/mcp/_sse.py
Normal file
116
agent_dhal/agentdhal_extensions/tools/mcp/_sse.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from agentdhal_core import Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
|
||||
from ._base import McpToolAdapter
|
||||
from ._config import SseServerParams
|
||||
|
||||
|
||||
class SseMcpToolAdapterConfig(BaseModel):
|
||||
"""Configuration for the MCP tool adapter."""
|
||||
|
||||
server_params: SseServerParams
|
||||
tool: Tool
|
||||
|
||||
|
||||
class SseMcpToolAdapter(
|
||||
McpToolAdapter[SseServerParams],
|
||||
Component[SseMcpToolAdapterConfig],
|
||||
):
|
||||
"""
|
||||
Allows you to wrap an MCP tool running over Server-Sent Events (SSE) and make it available to AutoGen.
|
||||
|
||||
This adapter enables using MCP-compatible tools that communicate over HTTP with SSE
|
||||
with AutoGen agents. Common use cases include integrating with remote MCP services,
|
||||
cloud-based tools, and web APIs that implement the Model Context Protocol (MCP).
|
||||
|
||||
.. note::
|
||||
|
||||
To use this class, you need to install `mcp` extra for the `autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mcp]"
|
||||
|
||||
Args:
|
||||
server_params (SseServerParameters): Parameters for the MCP server connection,
|
||||
including URL, headers, and timeouts.
|
||||
tool (Tool): The MCP tool to wrap.
|
||||
session (ClientSession, optional): The MCP client session to use. If not provided,
|
||||
it will create a new session. This is useful for testing or when you want to
|
||||
manage the session lifecycle yourself.
|
||||
|
||||
Examples:
|
||||
Use a remote translation service that implements MCP over SSE to create tools
|
||||
that allow AutoGen agents to perform translations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import SseMcpToolAdapter, SseServerParams
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create server params for the remote MCP service
|
||||
server_params = SseServerParams(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"},
|
||||
timeout=30, # Connection timeout in seconds
|
||||
)
|
||||
|
||||
# Get the translation tool from the server
|
||||
adapter = await SseMcpToolAdapter.from_server_params(server_params, "translate")
|
||||
|
||||
# Create an agent that can use the translation tool
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4")
|
||||
agent = AssistantAgent(
|
||||
name="translator",
|
||||
model_client=model_client,
|
||||
tools=[adapter],
|
||||
system_message="You are a helpful translation assistant.",
|
||||
)
|
||||
|
||||
# Let the agent translate some text
|
||||
await Console(
|
||||
agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken())
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = SseMcpToolAdapterConfig
|
||||
component_provider_override = "agentdhal_extensions.tools.mcp.SseMcpToolAdapter"
|
||||
|
||||
def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None:
|
||||
super().__init__(server_params=server_params, tool=tool, session=session)
|
||||
|
||||
def _to_config(self) -> SseMcpToolAdapterConfig:
|
||||
"""
|
||||
Convert the adapter to its configuration representation.
|
||||
|
||||
Returns:
|
||||
SseMcpToolAdapterConfig: The configuration of the adapter.
|
||||
"""
|
||||
return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
|
||||
"""
|
||||
Create an instance of SseMcpToolAdapter from its configuration.
|
||||
|
||||
Args:
|
||||
config (SseMcpToolAdapterConfig): The configuration of the adapter.
|
||||
|
||||
Returns:
|
||||
SseMcpToolAdapter: An instance of SseMcpToolAdapter.
|
||||
"""
|
||||
return cls(server_params=config.server_params, tool=config.tool)
|
||||
74
agent_dhal/agentdhal_extensions/tools/mcp/_stdio.py
Normal file
74
agent_dhal/agentdhal_extensions/tools/mcp/_stdio.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from agentdhal_core import Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
|
||||
from ._base import McpToolAdapter
|
||||
from ._config import StdioServerParams
|
||||
|
||||
|
||||
class StdioMcpToolAdapterConfig(BaseModel):
|
||||
"""Configuration for the MCP tool adapter."""
|
||||
|
||||
server_params: StdioServerParams
|
||||
tool: Tool
|
||||
|
||||
|
||||
class StdioMcpToolAdapter(
|
||||
McpToolAdapter[StdioServerParams],
|
||||
Component[StdioMcpToolAdapterConfig],
|
||||
):
|
||||
"""Allows you to wrap an MCP tool running over STDIO and make it available to AutoGen.
|
||||
|
||||
This adapter enables using MCP-compatible tools that communicate over standard input/output
|
||||
with AutoGen agents. Common use cases include wrapping command-line tools and local services
|
||||
that implement the Model Context Protocol (MCP).
|
||||
|
||||
.. note::
|
||||
|
||||
To use this class, you need to install `mcp` extra for the `autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mcp]"
|
||||
|
||||
|
||||
Args:
|
||||
server_params (StdioServerParams): Parameters for the MCP server connection,
|
||||
including command to run and its arguments
|
||||
tool (Tool): The MCP tool to wrap
|
||||
session (ClientSession, optional): The MCP client session to use. If not provided,
|
||||
a new session will be created. This is useful for testing or when you want to
|
||||
manage the session lifecycle yourself.
|
||||
|
||||
See :func:`~agentdhal_extensions.tools.mcp.mcp_server_tools` for examples.
|
||||
"""
|
||||
|
||||
component_config_schema = StdioMcpToolAdapterConfig
|
||||
component_provider_override = "agentdhal_extensions.tools.mcp.StdioMcpToolAdapter"
|
||||
|
||||
def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None:
|
||||
super().__init__(server_params=server_params, tool=tool, session=session)
|
||||
|
||||
def _to_config(self) -> StdioMcpToolAdapterConfig:
|
||||
"""
|
||||
Convert the adapter to its configuration representation.
|
||||
|
||||
Returns:
|
||||
StdioMcpToolAdapterConfig: The configuration of the adapter.
|
||||
"""
|
||||
return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self:
|
||||
"""
|
||||
Create an instance of StdioMcpToolAdapter from its configuration.
|
||||
|
||||
Args:
|
||||
config (StdioMcpToolAdapterConfig): The configuration of the adapter.
|
||||
|
||||
Returns:
|
||||
StdioMcpToolAdapter: An instance of StdioMcpToolAdapter.
|
||||
"""
|
||||
return cls(server_params=config.server_params, tool=config.tool)
|
||||
121
agent_dhal/agentdhal_extensions/tools/mcp/_streamable_http.py
Normal file
121
agent_dhal/agentdhal_extensions/tools/mcp/_streamable_http.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from agentdhal_core import Component
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
|
||||
from ._base import McpToolAdapter
|
||||
from ._config import StreamableHttpServerParams
|
||||
|
||||
|
||||
class StreamableHttpMcpToolAdapterConfig(BaseModel):
|
||||
"""Configuration for the MCP tool adapter."""
|
||||
|
||||
server_params: StreamableHttpServerParams
|
||||
tool: Tool
|
||||
|
||||
|
||||
class StreamableHttpMcpToolAdapter(
|
||||
McpToolAdapter[StreamableHttpServerParams],
|
||||
Component[StreamableHttpMcpToolAdapterConfig],
|
||||
):
|
||||
"""
|
||||
Allows you to wrap an MCP tool running over Streamable HTTP and make it available to AutoGen.
|
||||
|
||||
This adapter enables using MCP-compatible tools that communicate over Streamable HTTP
|
||||
with AutoGen agents. Common use cases include integrating with remote MCP services,
|
||||
cloud-based tools, and web APIs that implement the Model Context Protocol (MCP).
|
||||
|
||||
.. note::
|
||||
|
||||
To use this class, you need to install `mcp` extra for the `autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mcp]"
|
||||
|
||||
|
||||
Args:
|
||||
server_params (StreamableHttpServerParams): Parameters for the MCP server connection,
|
||||
including URL, headers, and timeouts.
|
||||
tool (Tool): The MCP tool to wrap.
|
||||
session (ClientSession, optional): The MCP client session to use. If not provided,
|
||||
it will create a new session. This is useful for testing or when you want to
|
||||
manage the session lifecycle yourself.
|
||||
|
||||
Examples:
|
||||
Use a remote translation service that implements MCP over Streamable HTTP to
|
||||
create tools that allow AutoGen agents to perform translations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import StreamableHttpMcpToolAdapter, StreamableHttpServerParams
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core import CancellationToken
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create server params for the remote MCP service
|
||||
server_params = StreamableHttpServerParams(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"},
|
||||
timeout=30.0, # HTTP timeout in seconds
|
||||
sse_read_timeout=300.0, # SSE read timeout in seconds (5 minutes)
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
# Get the translation tool from the server
|
||||
adapter = await StreamableHttpMcpToolAdapter.from_server_params(server_params, "translate")
|
||||
|
||||
# Create an agent that can use the translation tool
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4")
|
||||
agent = AssistantAgent(
|
||||
name="translator",
|
||||
model_client=model_client,
|
||||
tools=[adapter],
|
||||
system_message="You are a helpful translation assistant.",
|
||||
)
|
||||
|
||||
# Let the agent translate some text
|
||||
await Console(
|
||||
agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken())
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = StreamableHttpMcpToolAdapterConfig
|
||||
component_provider_override = "agentdhal_extensions.tools.mcp.StreamableHttpMcpToolAdapter"
|
||||
|
||||
def __init__(
|
||||
self, server_params: StreamableHttpServerParams, tool: Tool, session: ClientSession | None = None
|
||||
) -> None:
|
||||
super().__init__(server_params=server_params, tool=tool, session=session)
|
||||
|
||||
def _to_config(self) -> StreamableHttpMcpToolAdapterConfig:
|
||||
"""
|
||||
Convert the adapter to its configuration representation.
|
||||
|
||||
Returns:
|
||||
StreamableHttpMcpToolAdapterConfig: The configuration of the adapter.
|
||||
"""
|
||||
return StreamableHttpMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StreamableHttpMcpToolAdapterConfig) -> Self:
|
||||
"""
|
||||
Create an instance of StreamableHttpMcpToolAdapter from its configuration.
|
||||
|
||||
Args:
|
||||
config (StreamableHttpMcpToolAdapterConfig): The configuration of the adapter.
|
||||
|
||||
Returns:
|
||||
StreamableHttpMcpToolAdapter: An instance of StreamableHttpMcpToolAdapter.
|
||||
"""
|
||||
return cls(server_params=config.server_params, tool=config.tool)
|
||||
518
agent_dhal/agentdhal_extensions/tools/mcp/_workbench.py
Normal file
518
agent_dhal/agentdhal_extensions/tools/mcp/_workbench.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentModel, Image, trace_tool_span
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from agentdhal_core.tools import (
|
||||
ImageResultContent,
|
||||
ParametersSchema,
|
||||
TextResultContent,
|
||||
ToolOverride,
|
||||
ToolResult,
|
||||
ToolSchema,
|
||||
Workbench,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
ListPromptsResult,
|
||||
ListResourcesResult,
|
||||
ListResourceTemplatesResult,
|
||||
ListToolsResult,
|
||||
ReadResourceResult,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
from ._actor import McpSessionActor
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
|
||||
|
||||
|
||||
class McpWorkbenchConfig(BaseModel):
|
||||
server_params: McpServerParams
|
||||
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
|
||||
model_client: ComponentModel | Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class McpWorkbenchState(BaseModel):
|
||||
type: Literal["McpWorkBenchState"] = "McpWorkBenchState"
|
||||
|
||||
|
||||
class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
"""A workbench that wraps an MCP server and provides an interface
|
||||
to list and call tools provided by the server.
|
||||
|
||||
.. warning::
|
||||
|
||||
Only connect to trusted MCP servers, especially when using
|
||||
`StdioServerParams` as it executes commands in the local environment.
|
||||
|
||||
This workbench should be used as a context manager to ensure proper
|
||||
initialization and cleanup of the underlying MCP session.
|
||||
|
||||
.. list-table:: MCP Support
|
||||
:header-rows: 1
|
||||
:widths: 30 70
|
||||
|
||||
* - MCP Capability
|
||||
- Supported Features
|
||||
* - Tools
|
||||
- list_tools, call_tool
|
||||
* - Resources
|
||||
- list_resources, read_resource
|
||||
* - ResourceTemplates
|
||||
- list_resource_templates, read_resource_template
|
||||
* - Prompts
|
||||
- list_prompts, get_prompt
|
||||
* - Sampling
|
||||
- Optional support via model_client
|
||||
* - Roots
|
||||
- not supported
|
||||
* - Ellicitation
|
||||
- not supported
|
||||
|
||||
Args:
|
||||
server_params (McpServerParams): The parameters to connect to the MCP server.
|
||||
This can be either a :class:`StdioServerParams` or :class:`SseServerParams`.
|
||||
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
|
||||
names to override configurations for name and/or description. This allows
|
||||
customizing how server tools appear to consumers while maintaining the underlying
|
||||
tool functionality.
|
||||
model_client: Optional chat completion client to handle sampling requests
|
||||
from MCP servers that support the sampling capability. This allows MCP
|
||||
servers to request text generation from a language model during tool
|
||||
execution. If not provided, sampling requests will return an error.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are conflicts in tool override names.
|
||||
|
||||
Examples:
|
||||
|
||||
Here is a simple example of how to use the workbench with a `mcp-server-fetch` server:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
# You can also use `start()` and `stop()` to manage the session.
|
||||
async with McpWorkbench(server_params=params) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
print(tools)
|
||||
result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"})
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using tool overrides:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
|
||||
from agentdhal_core.tools import ToolOverride
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
# Override the fetch tool's name and description
|
||||
overrides = {
|
||||
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling")
|
||||
}
|
||||
|
||||
async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
# The tool will now appear as "web_fetch" with the new description
|
||||
print(tools)
|
||||
# Call the overridden tool
|
||||
result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"})
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using the workbench with the `GitHub MCP Server <https://github.com/github/github-mcp-server>`_:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
server_params = StdioServerParams(
|
||||
command="docker",
|
||||
args=[
|
||||
"run",
|
||||
"-i",
|
||||
"--rm",
|
||||
"-e",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN",
|
||||
"ghcr.io/github/github-mcp-server",
|
||||
],
|
||||
env={
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
|
||||
},
|
||||
)
|
||||
async with McpWorkbench(server_params) as mcp:
|
||||
agent = AssistantAgent(
|
||||
"github_assistant",
|
||||
model_client=model_client,
|
||||
workbench=mcp,
|
||||
reflect_on_tool_use=True,
|
||||
model_client_stream=True,
|
||||
)
|
||||
await Console(agent.run_stream(task="Is there a repository named Autogen"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using the workbench with the `Playwright MCP Server <https://github.com/microsoft/playwright-mcp>`_:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# First run `npm install -g @playwright/mcp@latest` to install the MCP server.
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.teams import RoundRobinGroupChat
|
||||
from agentdhal_agentchat.conditions import TextMessageTermination
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
|
||||
server_params = StdioServerParams(
|
||||
command="npx",
|
||||
args=[
|
||||
"@playwright/mcp@latest",
|
||||
"--headless",
|
||||
],
|
||||
)
|
||||
async with McpWorkbench(server_params) as mcp:
|
||||
agent = AssistantAgent(
|
||||
"web_browsing_assistant",
|
||||
model_client=model_client,
|
||||
workbench=mcp,
|
||||
model_client_stream=True,
|
||||
)
|
||||
team = RoundRobinGroupChat(
|
||||
[agent],
|
||||
termination_condition=TextMessageTermination(source="web_browsing_assistant"),
|
||||
)
|
||||
await Console(team.run_stream(task="Find out how many contributors for the microsoft/autogen repository"))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
component_provider_override = "agentdhal_extensions.tools.mcp.McpWorkbench"
|
||||
component_config_schema = McpWorkbenchConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_params: McpServerParams,
|
||||
tool_overrides: Optional[Dict[str, ToolOverride]] = None,
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
) -> None:
|
||||
self._server_params = server_params
|
||||
self._tool_overrides = tool_overrides or {}
|
||||
self._model_client = model_client
|
||||
|
||||
# Build reverse mapping from override names to original names for call_tool
|
||||
self._override_name_to_original: Dict[str, str] = {}
|
||||
for original_name, override in self._tool_overrides.items():
|
||||
override_name = override.name
|
||||
if override_name and override_name != original_name:
|
||||
# Check for conflicts with other override names
|
||||
if override_name in self._override_name_to_original:
|
||||
existing_original = self._override_name_to_original[override_name]
|
||||
raise ValueError(
|
||||
f"Tool override name '{override_name}' is used by multiple tools: "
|
||||
f"'{existing_original}' and '{original_name}'. Override names must be unique."
|
||||
)
|
||||
self._override_name_to_original[override_name] = original_name
|
||||
|
||||
# self._session: ClientSession | None = None
|
||||
self._actor: McpSessionActor | None = None
|
||||
self._actor_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._read = None
|
||||
self._write = None
|
||||
|
||||
@property
|
||||
def server_params(self) -> McpServerParams:
|
||||
return self._server_params
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
if not self._actor:
|
||||
await self.start() # fallback to start the actor if not initialized instead of raising an error
|
||||
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
|
||||
# raise RuntimeError("Actor is not initialized. Call start() first.")
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
result_future = await self._actor.call("list_tools", None)
|
||||
list_tool_result = await result_future
|
||||
assert isinstance(
|
||||
list_tool_result, ListToolsResult
|
||||
), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}"
|
||||
schema: List[ToolSchema] = []
|
||||
for tool in list_tool_result.tools:
|
||||
original_name = tool.name
|
||||
name = original_name
|
||||
description = tool.description or ""
|
||||
|
||||
# Apply overrides if they exist for this tool
|
||||
if original_name in self._tool_overrides:
|
||||
override = self._tool_overrides[original_name]
|
||||
if override.name is not None:
|
||||
name = override.name
|
||||
if override.description is not None:
|
||||
description = override.description
|
||||
|
||||
parameters = ParametersSchema(
|
||||
type="object",
|
||||
properties=tool.inputSchema.get("properties", {}),
|
||||
required=tool.inputSchema.get("required", []),
|
||||
additionalProperties=tool.inputSchema.get("additionalProperties", False),
|
||||
)
|
||||
tool_schema = ToolSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
)
|
||||
schema.append(tool_schema)
|
||||
return schema
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: Mapping[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
if not self._actor:
|
||||
await self.start() # fallback to start the actor if not initialized instead of raising an error
|
||||
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
|
||||
# raise RuntimeError("Actor is not initialized. Call start() first.")
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
|
||||
# Check if the name is an override name and map it back to the original
|
||||
original_name = self._override_name_to_original.get(name, name)
|
||||
|
||||
with trace_tool_span(
|
||||
tool_name=name, # Use the requested name for tracing
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
try:
|
||||
result_future = await self._actor.call("call_tool", {"name": original_name, "kargs": arguments})
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
assert isinstance(
|
||||
result, CallToolResult
|
||||
), f"call_tool must return a CallToolResult, instead of : {str(type(result))}"
|
||||
result_parts: List[TextResultContent | ImageResultContent] = []
|
||||
is_error = result.isError
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
result_parts.append(TextResultContent(content=content.text))
|
||||
elif isinstance(content, ImageContent):
|
||||
result_parts.append(ImageResultContent(content=Image.from_base64(content.data)))
|
||||
elif isinstance(content, EmbeddedResource):
|
||||
# TODO: how to handle embedded resources?
|
||||
# For now we just use text representation.
|
||||
result_parts.append(TextResultContent(content=content.model_dump_json()))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type from server: {type(content)}")
|
||||
except Exception as e:
|
||||
error_message = self._format_errors(e)
|
||||
is_error = True
|
||||
result_parts = [TextResultContent(content=error_message)]
|
||||
return ToolResult(name=name, result=result_parts, is_error=is_error) # Return the requested name
|
||||
|
||||
@property
|
||||
def initialize_result(self) -> Any:
|
||||
if self._actor:
|
||||
return self._actor.initialize_result
|
||||
|
||||
return None
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResult:
|
||||
"""List available prompts from the MCP server."""
|
||||
if not self._actor:
|
||||
await self.start()
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
|
||||
result_future = await self._actor.call("list_prompts", None)
|
||||
list_prompts_result = await result_future
|
||||
assert isinstance(
|
||||
list_prompts_result, ListPromptsResult
|
||||
), f"list_prompts must return a ListPromptsResult, instead of: {str(type(list_prompts_result))}"
|
||||
|
||||
return list_prompts_result
|
||||
|
||||
async def list_resources(self) -> ListResourcesResult:
|
||||
"""List available resources from the MCP server."""
|
||||
if not self._actor:
|
||||
await self.start()
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
|
||||
result_future = await self._actor.call("list_resources", None)
|
||||
list_resources_result = await result_future
|
||||
assert isinstance(
|
||||
list_resources_result, ListResourcesResult
|
||||
), f"list_resources must return a ListResourcesResult, instead of: {str(type(list_resources_result))}"
|
||||
|
||||
return list_resources_result
|
||||
|
||||
async def list_resource_templates(self) -> ListResourceTemplatesResult:
|
||||
"""List available resource templates from the MCP server."""
|
||||
if not self._actor:
|
||||
await self.start()
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
|
||||
result_future = await self._actor.call("list_resource_templates", None)
|
||||
list_templates_result = await result_future
|
||||
assert isinstance(
|
||||
list_templates_result, ListResourceTemplatesResult
|
||||
), f"list_resource_templates must return a ListResourceTemplatesResult, instead of: {str(type(list_templates_result))}"
|
||||
|
||||
return list_templates_result
|
||||
|
||||
async def read_resource(self, uri: str) -> ReadResourceResult:
|
||||
"""Read a resource from the MCP server."""
|
||||
if not self._actor:
|
||||
await self.start()
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
|
||||
result_future = await self._actor.call("read_resource", {"name": None, "kargs": {"uri": uri}})
|
||||
read_resource_result = await result_future
|
||||
assert isinstance(
|
||||
read_resource_result, ReadResourceResult
|
||||
), f"read_resource must return a ReadResourceResult, instead of: {str(type(read_resource_result))}"
|
||||
|
||||
return read_resource_result
|
||||
|
||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> GetPromptResult:
|
||||
"""Get a prompt from the MCP server."""
|
||||
if not self._actor:
|
||||
await self.start()
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
|
||||
result_future = await self._actor.call("get_prompt", {"name": name, "kargs": {"arguments": arguments}})
|
||||
get_prompt_result = await result_future
|
||||
assert isinstance(
|
||||
get_prompt_result, GetPromptResult
|
||||
), f"get_prompt must return a GetPromptResult, instead of: {str(type(get_prompt_result))}"
|
||||
|
||||
return get_prompt_result
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
|
||||
error_message = ""
|
||||
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
|
||||
# ExceptionGroup is available in Python 3.11+.
|
||||
# TODO: how to make this compatible with Python 3.10?
|
||||
for sub_exception in error.exceptions: # type: ignore
|
||||
error_message += self._format_errors(sub_exception) # type: ignore
|
||||
else:
|
||||
error_message += f"{str(error)}\n"
|
||||
return error_message
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._actor:
|
||||
warnings.warn(
|
||||
"McpWorkbench is already started. No need to start again.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return # Already initialized, no need to start again
|
||||
|
||||
if isinstance(self._server_params, (StdioServerParams, SseServerParams, StreamableHttpServerParams)):
|
||||
self._actor = McpSessionActor(self._server_params, model_client=self._model_client)
|
||||
await self._actor.initialize()
|
||||
self._actor_loop = asyncio.get_event_loop()
|
||||
else:
|
||||
raise ValueError(f"Unsupported server params type: {type(self._server_params)}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._actor:
|
||||
# Close the actor
|
||||
await self._actor.close()
|
||||
self._actor = None
|
||||
else:
|
||||
raise RuntimeError("McpWorkbench is not started. Call start() first.")
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return McpWorkbenchState().model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _to_config(self) -> McpWorkbenchConfig:
|
||||
model_client_config = None
|
||||
if self._model_client is not None:
|
||||
model_client_config = self._model_client.dump_component()
|
||||
return McpWorkbenchConfig(
|
||||
server_params=self._server_params, tool_overrides=self._tool_overrides, model_client=model_client_config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: McpWorkbenchConfig) -> Self:
|
||||
model_client = None
|
||||
if config.model_client is not None:
|
||||
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||
return cls(server_params=config.server_params, tool_overrides=config.tool_overrides, model_client=model_client)
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Ensure the actor is stopped when the workbench is deleted
|
||||
# Use getattr to safely handle cases where attributes may not be set (e.g., if __init__ failed)
|
||||
actor = getattr(self, "_actor", None)
|
||||
actor_loop = getattr(self, "_actor_loop", None)
|
||||
|
||||
if actor and actor_loop:
|
||||
if actor_loop.is_running() and not actor_loop.is_closed():
|
||||
actor_loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))
|
||||
else:
|
||||
msg = "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running"
|
||||
warnings.warn(msg, RuntimeWarning, stacklevel=2)
|
||||
@@ -0,0 +1,6 @@
|
||||
from ._kernel_function_from_tool import KernelFunctionFromTool, KernelFunctionFromToolSchema
|
||||
|
||||
__all__ = [
|
||||
"KernelFunctionFromTool",
|
||||
"KernelFunctionFromToolSchema",
|
||||
]
|
||||
@@ -0,0 +1,94 @@
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from semantic_kernel.functions import KernelFunctionFromMethod, KernelFunctionFromPrompt, kernel_function
|
||||
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
|
||||
from semantic_kernel.prompt_template.input_variable import InputVariable
|
||||
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
|
||||
|
||||
InputT = TypeVar("InputT", bound=BaseModel)
|
||||
OutputT = TypeVar("OutputT", bound=BaseModel)
|
||||
|
||||
|
||||
class KernelFunctionFromTool(KernelFunctionFromMethod):
|
||||
def __init__(self, tool: BaseTool[InputT, OutputT], plugin_name: str | None = None):
|
||||
# Get the pydantic model types from the tool
|
||||
args_type = tool.args_type()
|
||||
return_type = tool.return_type()
|
||||
|
||||
# 1) Define an async function that calls the tool
|
||||
@kernel_function(name=tool.name, description=tool.description)
|
||||
async def tool_method(**kwargs: dict[str, Any]) -> Any:
|
||||
return await tool.run_json(kwargs, cancellation_token=CancellationToken())
|
||||
|
||||
# Parse schema for parameters
|
||||
parameters_meta: list[KernelParameterMetadata] = []
|
||||
properties = tool.schema.get("parameters", {}).get("properties", {})
|
||||
|
||||
# Get the field types from the pydantic model
|
||||
field_types = args_type.model_fields
|
||||
|
||||
for prop_name, prop_info in properties.items():
|
||||
assert prop_name in field_types, f"Property {prop_name} not found in Tool {tool.name}"
|
||||
assert isinstance(prop_info, dict), f"Property {prop_name} is not a dict in Tool {tool.name}"
|
||||
|
||||
# Get the actual type from the pydantic model field
|
||||
field_type = field_types[prop_name]
|
||||
parameters_meta.append(
|
||||
KernelParameterMetadata(
|
||||
name=prop_name,
|
||||
description=field_type.description or "",
|
||||
default_value=field_type.get_default(),
|
||||
type=prop_info.get("type", "string"), # type: ignore
|
||||
type_object=field_type.annotation,
|
||||
is_required=field_type.is_required(),
|
||||
)
|
||||
)
|
||||
|
||||
# Create return parameter metadata
|
||||
return_parameter = KernelParameterMetadata(
|
||||
name="return",
|
||||
description=f"Result from '{tool.name}' tool",
|
||||
default_value=None,
|
||||
type="object" if issubclass(return_type, BaseModel) else "string",
|
||||
type_object=return_type,
|
||||
is_required=True,
|
||||
)
|
||||
|
||||
# Initialize the parent class
|
||||
super().__init__(
|
||||
method=tool_method,
|
||||
plugin_name=plugin_name,
|
||||
parameters=parameters_meta,
|
||||
return_parameter=return_parameter,
|
||||
additional_metadata=None,
|
||||
)
|
||||
|
||||
self._tool = tool
|
||||
|
||||
|
||||
class KernelFunctionFromToolSchema(KernelFunctionFromPrompt):
|
||||
def __init__(self, tool_schema: ToolSchema, plugin_name: str | None = None):
|
||||
properties = tool_schema.get("parameters", {}).get("properties", {})
|
||||
required = properties.get("required", [])
|
||||
|
||||
prompt_template_config = PromptTemplateConfig(
|
||||
name=tool_schema.get("name", ""),
|
||||
description=tool_schema.get("description", ""),
|
||||
input_variables=[
|
||||
InputVariable(
|
||||
name=prop_name, description=prop_info.get("description", ""), is_required=prop_name in required
|
||||
)
|
||||
for prop_name, prop_info in properties.items()
|
||||
],
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
function_name=tool_schema.get("name", ""),
|
||||
plugin_name=plugin_name,
|
||||
description=tool_schema.get("description", ""),
|
||||
prompt_template_config=prompt_template_config,
|
||||
)
|
||||
Reference in New Issue
Block a user