first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
from ._ai_search import (
AzureAISearchTool,
BaseAzureAISearchTool,
SearchQuery,
SearchResult,
SearchResults,
VectorizableTextQuery,
)
from ._config import AzureAISearchConfig
__all__ = [
"AzureAISearchTool",
"BaseAzureAISearchTool",
"SearchQuery",
"SearchResult",
"SearchResults",
"AzureAISearchConfig",
"VectorizableTextQuery",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,186 @@
"""Configuration for Azure AI Search tool.
This module provides the configuration class for the Azure AI Search tool, including
settings for authentication, search behavior, caching, and client-side embeddings.
"""
import logging
from typing import (
List,
Literal,
Optional,
TypeVar,
Union,
)
from pydantic import BaseModel, Field, field_validator, model_validator
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
# Type variable bound to the config class; it is not referenced elsewhere in this module.
T = TypeVar("T", bound="AzureAISearchConfig")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Closed set of values accepted by ``AzureAISearchConfig.query_type``.
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
# API version used as the default for ``AzureAISearchConfig.api_version``.
DEFAULT_API_VERSION = "2023-10-01-preview"
class AzureAISearchConfig(BaseModel):
    """Configuration for Azure AI Search with validation.

    This class defines the configuration parameters for Azure AI Search tools, including
    authentication, search behavior, caching, and embedding settings.

    .. note::
        This class requires the ``azure`` extra for the ``agentdhal-ext`` package.

        .. code-block:: bash

            pip install -U "agentdhal-ext[azure]"

    .. note::
        **Prerequisites:**

        1. An Azure AI Search service must be created in your Azure subscription.
        2. The search index must be properly configured for your use case:

           - For vector search: Index must have vector fields
           - For semantic search: Index must have semantic configuration
           - For hybrid search: Both vector fields and text fields must be configured

        3. Required packages:

           - Base functionality: ``azure-search-documents>=11.4.0``
           - For Azure OpenAI embeddings: ``openai azure-identity``
           - For OpenAI embeddings: ``openai``

    Example Usage:
        .. code-block:: python

            from azure.core.credentials import AzureKeyCredential
            from agentdhal_extensions.tools.azure import AzureAISearchConfig

            # Basic configuration for full-text search
            config = AzureAISearchConfig(
                name="doc-search",
                endpoint="https://your-search.search.windows.net",  # Your Azure AI Search endpoint
                index_name="<your-index>",  # Name of your search index
                credential=AzureKeyCredential("<your-key>"),  # Your Azure AI Search admin key
                query_type="simple",
                search_fields=["content", "title"],  # Update with your searchable fields
                top=5,
            )

            # Configuration for vector search with Azure OpenAI embeddings
            vector_config = AzureAISearchConfig(
                name="vector-search",
                endpoint="https://your-search.search.windows.net",
                index_name="<your-index>",
                credential=AzureKeyCredential("<your-key>"),
                query_type="vector",
                vector_fields=["embedding"],  # Update with your vector field name
                embedding_provider="azure_openai",
                embedding_model="text-embedding-ada-002",
                openai_endpoint="https://your-openai.openai.azure.com",  # Your Azure OpenAI endpoint
                openai_api_key="<your-openai-key>",  # Your Azure OpenAI key
                top=5,
            )

            # Configuration for hybrid search with semantic ranking
            hybrid_config = AzureAISearchConfig(
                name="hybrid-search",
                endpoint="https://your-search.search.windows.net",
                index_name="<your-index>",
                credential=AzureKeyCredential("<your-key>"),
                query_type="semantic",
                semantic_config_name="<your-semantic-config>",  # Name of your semantic configuration
                search_fields=["content", "title"],  # Update with your search fields
                vector_fields=["embedding"],  # Update with your vector field name
                embedding_provider="openai",
                embedding_model="text-embedding-ada-002",
                openai_api_key="<your-openai-key>",  # Your OpenAI API key
                top=5,
            )
    """

    name: str = Field(description="The name of this tool instance")
    description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
    endpoint: str = Field(description="The full URL of your Azure AI Search service")
    index_name: str = Field(description="Name of the search index to query")
    credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
        description="Azure credential for authentication (API key or token)"
    )
    api_version: str = Field(
        default=DEFAULT_API_VERSION,
        description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
    )
    query_type: QueryTypeLiteral = Field(
        default="simple", description="Type of search to perform: simple, full, semantic, or vector"
    )
    search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
    select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
    vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
    top: Optional[int] = Field(
        default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
    )
    filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
    semantic_config_name: Optional[str] = Field(
        default=None, description="Semantic configuration name for enhanced results"
    )

    enable_caching: bool = Field(default=False, description="Whether to cache search results")
    cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")

    embedding_provider: Optional[str] = Field(
        default=None, description="Name of embedding provider for client-side embeddings"
    )
    embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
    openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
    openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
    openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")

    # The credential classes are not pydantic models, so arbitrary types must be allowed.
    model_config = {"arbitrary_types_allowed": True}

    @field_validator("endpoint")
    @classmethod
    def validate_endpoint(cls, v: str) -> str:
        """Validate that the endpoint is a valid URL.

        Raises:
            ValueError: If ``v`` does not start with ``http://`` or ``https://``.
        """
        if not v.startswith(("http://", "https://")):
            raise ValueError("endpoint must be a valid URL starting with http:// or https://")
        return v

    # mode="before" is required here: in the default "after" mode the Literal check on
    # ``query_type`` rejects "fulltext" and empty values before this normalizer can run.
    @field_validator("query_type", mode="before")
    @classmethod
    def normalize_query_type(cls, v: object) -> object:
        """Normalize query type to standard values.

        ``None`` and the empty string fall back to ``"simple"``; the alias ``"fulltext"``
        (case-insensitive) maps to ``"full"``. Any other value is returned unchanged and is
        then checked against the allowed literals by pydantic.
        """
        if v is None or v == "":
            return "simple"
        if isinstance(v, str) and v.lower() == "fulltext":
            return "full"
        return v

    @field_validator("top")
    @classmethod
    def validate_top(cls, v: Optional[int]) -> Optional[int]:
        """Ensure top is a positive integer if provided.

        Raises:
            ValueError: If ``v`` is not ``None`` and is less than or equal to zero.
        """
        if v is not None and v <= 0:
            raise ValueError("top must be a positive integer")
        return v

    @model_validator(mode="after")
    def validate_interdependent_fields(self) -> "AzureAISearchConfig":
        """Validate interdependent fields after all fields have been parsed.

        Raises:
            ValueError: If ``query_type`` is ``"semantic"`` without ``semantic_config_name``,
                if ``query_type`` is ``"vector"`` without ``vector_fields``, or if the
                ``azure_openai`` embedding provider is configured with a model but no
                ``openai_endpoint``.
        """
        if self.query_type == "semantic" and not self.semantic_config_name:
            raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")

        if self.query_type == "vector" and not self.vector_fields:
            raise ValueError("vector_fields must be provided for vector search")

        if (
            self.embedding_provider
            and self.embedding_provider.lower() == "azure_openai"
            and self.embedding_model
            and not self.openai_endpoint
        ):
            raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")

        return self

View File

@@ -0,0 +1,3 @@
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
__all__ = ["CodeExecutionInput", "CodeExecutionResult", "PythonCodeExecutionTool"]

View File

@@ -0,0 +1,96 @@
from agentdhal_core import CancellationToken, Component, ComponentModel
from agentdhal_core.code_executor import CodeBlock, CodeExecutor
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field, model_serializer
from typing_extensions import Self
class CodeExecutionInput(BaseModel):
    """Arguments for the code execution tool: one string of Python source."""

    code: str = Field(
        description="The contents of the Python code block that should be executed",
    )
class CodeExecutionResult(BaseModel):
    """Result of one execution: a success flag and the output text."""

    success: bool
    output: str

    @model_serializer
    def ser_model(self) -> str:
        """Serialize the model as its ``output`` string only (``success`` is not emitted)."""
        serialized: str = self.output
        return serialized
class PythonCodeExecutionToolConfig(BaseModel):
    """Configuration for PythonCodeExecutionTool"""

    # Serialized component model of the executor that runs the code.
    executor: ComponentModel
    # Human-readable description of the tool.
    description: str = "Execute Python code blocks."
class PythonCodeExecutionTool(
    BaseTool[CodeExecutionInput, CodeExecutionResult], Component[PythonCodeExecutionToolConfig]
):
    """A tool that executes Python code in a code executor and returns output.

    Example executors:

    * :class:`agentdhal_extensions.code_executors.local.LocalCommandLineCodeExecutor`
    * :class:`agentdhal_extensions.code_executors.docker.DockerCommandLineCodeExecutor`
    * :class:`agentdhal_extensions.code_executors.azure.ACADynamicSessionsCodeExecutor`

    Example usage:

    .. code-block:: bash

        pip install -U "agentdhal-agentchat" "agentdhal-ext[openai]" "yfinance" "matplotlib"

    .. code-block:: python

        import asyncio
        from agentdhal_agentchat.agents import AssistantAgent
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_extensions.code_executors.local import LocalCommandLineCodeExecutor
        from agentdhal_extensions.tools.code_execution import PythonCodeExecutionTool


        async def main() -> None:
            tool = PythonCodeExecutionTool(LocalCommandLineCodeExecutor(work_dir="coding"))
            agent = AssistantAgent(
                "assistant", OpenAIChatCompletionClient(model="gpt-4o"), tools=[tool], reflect_on_tool_use=True
            )
            await Console(
                agent.run_stream(
                    task="Create a plot of MSFT stock prices in 2024 and save it to a file. Use yfinance and matplotlib."
                )
            )


        asyncio.run(main())


    Args:
        executor (CodeExecutor): The code executor that will be used to execute the code blocks.
        description (str, optional): Description advertised for the tool.
            Defaults to "Execute Python code blocks.".
    """

    component_config_schema = PythonCodeExecutionToolConfig
    component_provider_override = "agentdhal_extensions.tools.code_execution.PythonCodeExecutionTool"

    def __init__(self, executor: CodeExecutor, description: str = "Execute Python code blocks."):
        super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", description)
        self._executor = executor
        # Kept separately so _to_config can round-trip the description without
        # depending on how the base class stores it.
        self._config_description = description

    async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
        """Execute ``args.code`` as a single Python code block.

        Returns:
            A result whose ``success`` is true when the executor reports exit code 0 and
            whose ``output`` is the executor's output.
        """
        code_blocks = [CodeBlock(code=args.code, language="python")]
        result = await self._executor.execute_code_blocks(
            code_blocks=code_blocks, cancellation_token=cancellation_token
        )
        return CodeExecutionResult(success=result.exit_code == 0, output=result.output)

    def _to_config(self) -> PythonCodeExecutionToolConfig:
        """Convert current instance to config object"""
        return PythonCodeExecutionToolConfig(
            executor=self._executor.dump_component(),
            description=self._config_description,
        )

    @classmethod
    def _from_config(cls, config: PythonCodeExecutionToolConfig) -> Self:
        """Create instance from config object"""
        executor = CodeExecutor.load_component(config.executor)
        # Honor the configured description instead of silently dropping it.
        return cls(executor=executor, description=config.description)

View File

@@ -0,0 +1,25 @@
from ._config import (
GlobalContextConfig,
GlobalDataConfig,
LocalContextConfig,
LocalDataConfig,
MapReduceConfig,
SearchConfig,
)
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
__all__ = [
"GlobalSearchTool",
"LocalSearchTool",
"GlobalDataConfig",
"LocalDataConfig",
"GlobalContextConfig",
"GlobalSearchToolArgs",
"GlobalSearchToolReturn",
"LocalContextConfig",
"LocalSearchToolArgs",
"LocalSearchToolReturn",
"MapReduceConfig",
"SearchConfig",
]

View File

@@ -0,0 +1,59 @@
from pydantic import BaseModel
class DataConfig(BaseModel):
    """Settings shared by the search tools for locating indexed data."""

    # Directory used as the prefix when the tools build ``<input_dir>/<table>.parquet`` paths.
    input_dir: str
    # Table names (file stems, without the ``.parquet`` suffix).
    entity_table: str = "entities"
    entity_embedding_table: str = "entities"
    community_table: str = "communities"
    # Community level passed to the indexer adapter readers.
    community_level: int = 2
class GlobalDataConfig(DataConfig):
    """Data settings for global search; adds the community report table name."""

    community_report_table: str = "community_reports"
class LocalDataConfig(DataConfig):
    """Data settings for local search; adds the relationship and text unit table names."""

    relationship_table: str = "relationships"
    text_unit_table: str = "text_units"
class ContextConfig(BaseModel):
    """Base context-building settings shared by global and local search."""

    # Token budget; the search tools pass it to the context builder as ``max_tokens``.
    max_data_tokens: int = 8000
class GlobalContextConfig(ContextConfig):
    """Context-building settings for global search.

    Each field except ``max_data_tokens`` is forwarded under its own name as a
    context builder parameter by ``GlobalSearchTool``.
    """

    use_community_summary: bool = False
    shuffle_data: bool = True
    include_community_rank: bool = True
    min_community_rank: int = 0
    community_rank_name: str = "rank"
    include_community_weight: bool = True
    community_weight_name: str = "occurrence weight"
    normalize_community_weight: bool = True
    # Overrides the base class default of 8000.
    max_data_tokens: int = 12000
class LocalContextConfig(ContextConfig):
    """Context-building settings for local search.

    Each field is forwarded under its own name as a context builder parameter by
    ``LocalSearchTool``.
    """

    text_unit_prop: float = 0.5
    community_prop: float = 0.25
    include_entity_rank: bool = True
    rank_description: str = "number of relationships"
    include_relationship_weight: bool = True
    relationship_ranking_attribute: str = "rank"
class MapReduceConfig(BaseModel):
    """Model-call settings for the map and reduce phases of global search."""

    # Map phase.
    map_max_tokens: int = 1000
    map_temperature: float = 0.0
    # Reduce phase.
    reduce_max_tokens: int = 2000
    reduce_temperature: float = 0.0
    # Options forwarded as-is to the global search engine.
    allow_general_knowledge: bool = False
    json_mode: bool = False
    response_type: str = "multiple paragraphs"
class SearchConfig(BaseModel):
    """Model-call settings for local search."""

    max_tokens: int = 1500
    temperature: float = 0.0
    response_type: str = "multiple paragraphs"

View File

@@ -0,0 +1,233 @@
from pathlib import Path
import pandas as pd
import tiktoken
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol import ChatModel
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_entities,
read_indexer_reports,
)
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig
# Shared default config instances, used as default argument values by GlobalSearchTool.
# The visible code only reads their fields and never mutates them.
_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()
class GlobalSearchToolArgs(BaseModel):
    """Input schema for the global search tool."""

    query: str = Field(
        ...,
        description="The user query to perform global search on.",
    )
class GlobalSearchToolReturn(BaseModel):
    """Output schema for the global search tool."""

    # Text response produced by the search engine.
    answer: str
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
    """Enables running GraphRAG global search queries as an AutoGen tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines graph-based document relationships with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`agentdhal-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        Before using this tool, you must complete the GraphRAG setup and indexing process:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        Please refer to the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
        for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import GlobalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up global search tool
            global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the global search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[global_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
                ),
            )

            # Run a sample query
            query = "What is the overall sentiment of the community reports?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())

    Args:
        token_encoder (tiktoken.Encoding): The tokenizer handed to the context builder and search engine
        model: The chat model to use for search (GraphRAG ChatModel)
        data_config (DataConfig): Configuration for data source locations and settings
        context_config (ContextConfig, optional): Configuration for context building. Defaults to default config.
        mapreduce_config (MapReduceConfig, optional): Configuration for the map and reduce calls. Defaults to default config.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,
        data_config: DataConfig,
        context_config: ContextConfig = _default_context_config,
        mapreduce_config: MapReduceConfig = _default_mapreduce_config,
    ):
        super().__init__(
            args_type=GlobalSearchToolArgs,
            return_type=GlobalSearchToolReturn,
            name="global_search_tool",
            description="Perform a global search with given parameters using graphrag.",
        )
        self._model = model

        def _read_table(table_name: str) -> pd.DataFrame:
            """Read ``<input_dir>/<table_name>.parquet`` into a DataFrame."""
            return pd.read_parquet(f"{data_config.input_dir}/{table_name}.parquet")  # type: ignore

        # Same read order as before: communities, entities, then community reports.
        community_df = _read_table(data_config.community_table)
        entity_df = _read_table(data_config.entity_table)
        report_df = _read_table(data_config.community_report_table)

        # Convert the raw frames through the GraphRAG indexer adapters.
        level = data_config.community_level
        communities = read_indexer_communities(community_df, report_df)
        reports = read_indexer_reports(report_df, community_df, level)
        entities = read_indexer_entities(entity_df, community_df, level)

        context_builder = GlobalCommunityContext(
            community_reports=reports,
            communities=communities,
            entities=entities,
            token_encoder=token_encoder,
        )

        context_builder_params = dict(
            use_community_summary=context_config.use_community_summary,
            shuffle_data=context_config.shuffle_data,
            include_community_rank=context_config.include_community_rank,
            min_community_rank=context_config.min_community_rank,
            community_rank_name=context_config.community_rank_name,
            include_community_weight=context_config.include_community_weight,
            community_weight_name=context_config.community_weight_name,
            normalize_community_weight=context_config.normalize_community_weight,
            max_tokens=context_config.max_data_tokens,
            context_name="Reports",
        )

        # The map call always requests a JSON object response.
        map_llm_params = dict(
            max_tokens=mapreduce_config.map_max_tokens,
            temperature=mapreduce_config.map_temperature,
            response_format={"type": "json_object"},
        )
        reduce_llm_params = dict(
            max_tokens=mapreduce_config.reduce_max_tokens,
            temperature=mapreduce_config.reduce_temperature,
        )

        self._search_engine = GlobalSearch(
            model=self._model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            max_data_tokens=context_config.max_data_tokens,
            map_llm_params=map_llm_params,
            reduce_llm_params=reduce_llm_params,
            allow_general_knowledge=mapreduce_config.allow_general_knowledge,
            json_mode=mapreduce_config.json_mode,
            context_builder_params=context_builder_params,
            concurrent_coroutines=32,
            response_type=mapreduce_config.response_type,
        )

    async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
        """Run a global search for ``args.query`` and wrap the string response.

        ``cancellation_token`` is accepted but not used by this method.
        """
        outcome = await self._search_engine.search(args.query)
        answer = outcome.response
        assert isinstance(answer, str), "Expected response to be a string"
        return GlobalSearchToolReturn(answer=answer)

    @classmethod
    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
        """Create a GlobalSearchTool instance from GraphRAG settings file.

        Args:
            root_dir: Path to the GraphRAG root directory
            config_filepath: Path to the GraphRAG settings file (optional)

        Returns:
            An initialized GlobalSearchTool instance

        Raises:
            ValueError: If the default chat model is missing from ``config.models``.
        """
        # String arguments are converted to Path objects; anything else passes through untouched.
        root_path = Path(root_dir) if isinstance(root_dir, str) else root_dir
        settings_path = Path(config_filepath) if isinstance(config_filepath, str) else config_filepath
        config = load_config(root_dir=root_path, config_filepath=settings_path)

        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        if chat_model_config is None:
            raise ValueError("default_chat_model not found in config.models")

        # Pick the tokenizer for the configured model; tiktoken raises KeyError for
        # models it does not recognize, in which case cl100k_base is used.
        try:
            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
        except KeyError:
            token_encoder = tiktoken.get_encoding("cl100k_base")

        model = ModelManager().get_or_create_chat_model(
            name="global_search_model",
            model_type=chat_model_config.type,
            config=chat_model_config,
        )

        # The configured output directory is where the tool reads its parquet tables from.
        data_config = DataConfig(input_dir=str(config.output.base_dir))

        return cls(
            token_encoder=token_encoder,
            model=model,
            data_config=data_config,
            context_config=_default_context_config,
            mapreduce_config=_default_mapreduce_config,
        )

View File

@@ -0,0 +1,245 @@
# mypy: disable-error-code="no-any-unimported,misc"
from pathlib import Path
import pandas as pd
import tiktoken
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_text_units,
)
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from ._config import LocalContextConfig, SearchConfig
from ._config import LocalDataConfig as DataConfig
# Shared default config instances, used as default argument values by LocalSearchTool.
# The visible code only reads their fields and never mutates them.
_default_context_config = LocalContextConfig()
_default_search_config = SearchConfig()
class LocalSearchToolArgs(BaseModel):
    """Input schema for the local search tool."""

    query: str = Field(
        ...,
        description="The user query to perform local search on.",
    )
class LocalSearchToolReturn(BaseModel):
    """Output schema for the local search tool."""

    answer: str = Field(
        ...,
        description="The answer to the user query.",
    )
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
    """Enables running GraphRAG local search queries as an AutoGen tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines local document context with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`agentdhal-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        Before using this tool, you must complete the GraphRAG setup and indexing process:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        Please refer to the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
        for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import LocalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up local search tool
            local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the local search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[local_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For specific, detailed information about particular entities or relationships, call the 'local_search' function."
                ),
            )

            # Run a sample query
            query = "What does the station-master say about Dr. Becher?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())

    Args:
        token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
        model: The chat model to use for search (GraphRAG ChatModel)
        embedder: The text embedding model to use (GraphRAG EmbeddingModel)
        data_config (DataConfig): Configuration for data source locations and settings
        context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
        search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,  # ChatModel from GraphRAG
        embedder: EmbeddingModel,  # EmbeddingModel from GraphRAG
        data_config: DataConfig,
        context_config: LocalContextConfig = _default_context_config,
        search_config: SearchConfig = _default_search_config,
    ):
        super().__init__(
            args_type=LocalSearchToolArgs,
            return_type=LocalSearchToolReturn,
            name="local_search_tool",
            description="Perform a local search with given parameters using graphrag.",
        )
        # Use the provided models
        self._model = model
        self._embedder = embedder

        # Load parquet files
        entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")  # type: ignore
        relationship_df: pd.DataFrame = pd.read_parquet(  # type: ignore
            f"{data_config.input_dir}/{data_config.relationship_table}.parquet"
        )
        text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet")  # type: ignore
        community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")  # type: ignore

        # Read data using indexer adapters
        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
        relationships = read_indexer_relationships(relationship_df)
        text_units = read_indexer_text_units(text_unit_df)

        # Set up vector store for entity embeddings; the LanceDB database is
        # expected under ``<input_dir>/lancedb``.
        description_embedding_store = LanceDBVectorStore(
            collection_name="default-entity-description",
        )
        description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb")

        # Set up context builder
        context_builder = LocalSearchMixedContext(
            entities=entities,
            entity_text_embeddings=description_embedding_store,
            text_embedder=self._embedder,
            text_units=text_units,
            relationships=relationships,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "text_unit_prop": context_config.text_unit_prop,
            "community_prop": context_config.community_prop,
            "include_entity_rank": context_config.include_entity_rank,
            "rank_description": context_config.rank_description,
            "include_relationship_weight": context_config.include_relationship_weight,
            "relationship_ranking_attribute": context_config.relationship_ranking_attribute,
            "max_tokens": context_config.max_data_tokens,
        }

        llm_params = {
            "max_tokens": search_config.max_tokens,
            "temperature": search_config.temperature,
        }

        self._search_engine = LocalSearch(
            model=self._model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            response_type=search_config.response_type,
            context_builder_params=context_builder_params,
            model_params=llm_params,
        )

    async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
        """Run a local search for ``args.query`` and wrap the string response.

        ``cancellation_token`` is accepted but not used by this method.
        """
        search_result = await self._search_engine.search(args.query)  # type: ignore[reportUnknownMemberType]
        assert isinstance(search_result.response, str), "Expected response to be a string"
        return LocalSearchToolReturn(answer=search_result.response)

    @classmethod
    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "LocalSearchTool":
        """Create a LocalSearchTool instance from GraphRAG settings file.

        Args:
            root_dir: Path to the GraphRAG root directory (``str`` or ``Path``)
            config_filepath: Path to the GraphRAG settings file (``str`` or ``Path``, optional)

        Returns:
            An initialized LocalSearchTool instance

        Raises:
            ValueError: If the default chat model or the default embedding model is
                missing from ``config.models``.
        """
        # Accept plain strings as well as Path objects, matching GlobalSearchTool.from_settings.
        if isinstance(root_dir, str):
            root_dir = Path(root_dir)
        if isinstance(config_filepath, str):
            config_filepath = Path(config_filepath)

        # Load GraphRAG config
        config = load_config(root_dir=root_dir, config_filepath=config_filepath)

        # Get the language model configurations from the models section
        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)

        if chat_model_config is None:
            raise ValueError("default_chat_model not found in config.models")
        if embedding_model_config is None:
            raise ValueError("default_embedding_model not found in config.models")

        # Initialize token encoder based on the model being used
        try:
            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
        except KeyError:
            # Fallback to cl100k_base if model is not recognized by tiktoken
            token_encoder = tiktoken.get_encoding("cl100k_base")

        # Create the models using ModelManager
        model = ModelManager().get_or_create_chat_model(
            name="local_search_model",
            model_type=chat_model_config.type,
            config=chat_model_config,
        )
        embedder = ModelManager().get_or_create_embedding_model(
            name="local_search_embedder",
            model_type=embedding_model_config.type,
            config=embedding_model_config,
        )

        # Create data config from storage paths
        data_config = DataConfig(
            input_dir=str(config.output.base_dir),
        )

        return cls(
            token_encoder=token_encoder,
            model=model,
            embedder=embedder,
            data_config=data_config,
            context_config=_default_context_config,
            search_config=_default_search_config,
        )

View File

@@ -0,0 +1,3 @@
from ._http_tool import HttpTool
__all__ = ["HttpTool"]

View File

@@ -0,0 +1,244 @@
import re
from typing import Any, Literal, Optional, Type
import httpx
from agentdhal_core import CancellationToken, Component
from agentdhal_core.tools import BaseTool
from json_schema_to_pydantic import create_model
from pydantic import BaseModel, Field
from typing_extensions import Self
# Default value of ``HttpToolConfig.timeout``, in seconds.
DEFAULT_TIMEOUT_CONFIG = 5.0
class HttpToolConfig(BaseModel):
    """Serializable configuration for :class:`HttpTool`.

    ``description`` and ``headers`` default to ``None`` so that they are optional,
    as the tool's documentation states, rather than required-but-nullable.
    """

    name: str
    """
    The name of the tool.
    """
    description: Optional[str] = None
    """
    A description of the tool.
    """
    scheme: Literal["http", "https"] = "http"
    """
    The scheme to use for the request.
    """
    host: str
    """
    The host to send the request to.
    """
    port: int
    """
    The port to send the request to.
    """
    path: str = Field(default="/")
    """
    The path to send the request to. Defaults to "/".
    The path can accept parameters, e.g. "/{param1}/{param2}".
    These parameters will be templated from the input args; any additional parameters will be added as query
    parameters or the body of the request.
    """
    method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST"
    """
    The HTTP method to use, will default to POST if not provided.
    """
    headers: Optional[dict[str, Any]] = None
    """
    A dictionary of headers to send with the request.
    """
    json_schema: dict[str, Any]
    """
    A JSON Schema object defining the expected parameters for the tool.
    Path parameters MUST also be included in the json_schema, and they MUST be of type string.
    """
    return_type: Optional[Literal["text", "json"]] = "text"
    """
    The type of response to return from the tool.
    """
    timeout: float = DEFAULT_TIMEOUT_CONFIG
    """
    The timeout for the tool request in seconds.
    """
class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
"""A wrapper for using an HTTP server as a tool.
Args:
name (str): The name of the tool.
description (str, optional): A description of the tool.
scheme (str): The scheme to use for the request. Must be either "http" or "https".
host (str): The host to send the request to.
port (int): The port to send the request to.
path (str, optional): The path to send the request to. Defaults to "/".
Can include path parameters like "/{param1}/{param2}" which will be templated from input args.
method (str, optional): The HTTP method to use, will default to POST if not provided.
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
Path parameters must also be included in the schema and must be strings.
return_type (Literal["text", "json"], optional): The type of response to return from the tool.
Defaults to "text".
timeout (float, optional): The timeout for HTTP requests in seconds.
Defaults to 5.0.
.. note::
This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "agentdhal-agentchat" "agentdhal-ext[http-tool]"
Example:
Simple use case::
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.messages import TextMessage
from agentdhal_core import CancellationToken
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.http import HttpTool
# Define a JSON schema for a base64 decode tool
base64_schema = {
"type": "object",
"properties": {
"value": {"type": "string", "description": "The base64 value to decode"},
},
"required": ["value"],
}
# Create an HTTP tool for the httpbin API
base64_tool = HttpTool(
name="base64_decode",
description="base64 decode a value",
scheme="https",
host="httpbin.org",
port=443,
path="/base64/{value}",
method="GET",
json_schema=base64_schema,
)
async def main():
# Create an assistant with the base64 tool
model = OpenAIChatCompletionClient(model="gpt-4")
assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool])
# The assistant can now use the base64 tool to decode the string
response = await assistant.on_messages(
[TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")],
CancellationToken(),
)
print(response.chat_message)
asyncio.run(main())
"""
component_type = "tool"
component_provider_override = "agentdhal_extensions.tools.http.HttpTool"
component_config_schema = HttpToolConfig
def __init__(
self,
name: str,
host: str,
port: int,
json_schema: dict[str, Any],
headers: Optional[dict[str, Any]] = None,
description: str = "HTTP tool",
path: str = "/",
scheme: Literal["http", "https"] = "http",
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
return_type: Literal["text", "json"] = "text",
timeout: float = DEFAULT_TIMEOUT_CONFIG,
) -> None:
self.server_params = HttpToolConfig(
name=name,
description=description,
host=host,
port=port,
path=path,
scheme=scheme,
method=method,
headers=headers,
json_schema=json_schema,
return_type=return_type,
timeout=timeout,
)
# Use regex to find all path parameters, we will need those later to template the path
path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)}
self._path_params = path_params
# Create the input model from the modified schema
input_model = create_model(json_schema)
# Use Any as return type since HTTP responses can vary
base_return_type: Type[Any] = object
super().__init__(input_model, base_return_type, name, description)
def _to_config(self) -> HttpToolConfig:
copied_config = self.server_params.model_copy()
return copied_config
@classmethod
def _from_config(cls, config: HttpToolConfig) -> Self:
copied_config = config.model_copy().model_dump()
return cls(**copied_config)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
"""Execute the HTTP tool with the given arguments.
Args:
args: The validated input arguments
cancellation_token: Token for cancelling the operation
Returns:
The response body from the HTTP call in JSON format
Raises:
Exception: If tool execution fails
"""
model_dump = args.model_dump()
path_params = {k: v for k, v in model_dump.items() if k in self._path_params}
# Remove path params from the model dump
for k in self._path_params:
model_dump.pop(k)
path = self.server_params.path.format(**path_params)
url = httpx.URL(
scheme=self.server_params.scheme,
host=self.server_params.host,
port=self.server_params.port,
path=path,
)
timeout_config = httpx.Timeout(timeout=self.server_params.timeout)
async with httpx.AsyncClient(timeout=timeout_config) as client:
match self.server_params.method:
case "GET":
response = await client.get(url, headers=self.server_params.headers, params=model_dump)
case "PUT":
response = await client.put(url, headers=self.server_params.headers, json=model_dump)
case "DELETE":
response = await client.delete(url, headers=self.server_params.headers, params=model_dump)
case "PATCH":
response = await client.patch(url, headers=self.server_params.headers, json=model_dump)
case _: # Default case POST
response = await client.post(url, headers=self.server_params.headers, json=model_dump)
match self.server_params.return_type:
case "text":
return response.text
case "json":
return response.json()
case _:
raise ValueError(f"Invalid return type: {self.server_params.return_type}")

View File

@@ -0,0 +1,3 @@
from ._langchain_adapter import LangChainToolAdapter
__all__ = ["LangChainToolAdapter"]

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
import asyncio
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, cast
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel, Field, create_model
if TYPE_CHECKING:
from langchain_core.tools import BaseTool as LangChainTool
class LangChainToolAdapter(BaseTool[BaseModel, Any]):
    """Allows you to wrap a LangChain tool and make it available to AutoGen.

    .. note::

        This class requires the :code:`langchain` extra for the :code:`autogen-ext` package.

    .. code-block:: bash

        pip install -U "agentdhal-ext[langchain]"

    Args:
        langchain_tool (LangChainTool): A LangChain tool to wrap

    Examples:

        Use the `PythonAstREPLTool` from the `langchain_experimental` package to
        create a tool that allows you to interact with a Pandas DataFrame.

        .. code-block:: python

            import asyncio
            import pandas as pd
            from langchain_experimental.tools.python.tool import PythonAstREPLTool
            from agentdhal_extensions.tools.langchain import LangChainToolAdapter
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.messages import TextMessage
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.ui import Console
            from agentdhal_core import CancellationToken


            async def main() -> None:
                df = pd.read_csv("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv")  # type: ignore
                tool = LangChainToolAdapter(PythonAstREPLTool(locals={"df": df}))
                model_client = OpenAIChatCompletionClient(model="gpt-4o")
                agent = AssistantAgent(
                    "assistant",
                    tools=[tool],
                    model_client=model_client,
                    system_message="Use the `df` variable to access the dataset.",
                )
                await Console(
                    agent.on_messages_stream(
                        [TextMessage(content="What's the average age of the passengers?", source="user")], CancellationToken()
                    )
                )


            asyncio.run(main())

        This example demonstrates how to use the `SQLDatabaseToolkit` from the `langchain_community`
        package to interact with an SQLite database.
        It uses the :class:`~agentdhal_agentchat.teams.RoundRobinGroupChat` to iterate the single agent over multiple steps.
        If you want to run one step at a time, you can just call the `run_stream` method of the
        :class:`~agentdhal_agentchat.agents.AssistantAgent` class directly.

        .. code-block:: python

            import asyncio
            import sqlite3

            import requests
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_agentchat.ui import Console
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.langchain import LangChainToolAdapter
            from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
            from langchain_community.utilities.sql_database import SQLDatabase
            from langchain_openai import ChatOpenAI
            from sqlalchemy import Engine, create_engine
            from sqlalchemy.pool import StaticPool


            def get_engine_for_chinook_db() -> Engine:
                url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
                response = requests.get(url)
                sql_script = response.text
                connection = sqlite3.connect(":memory:", check_same_thread=False)
                connection.executescript(sql_script)
                return create_engine(
                    "sqlite://",
                    creator=lambda: connection,
                    poolclass=StaticPool,
                    connect_args={"check_same_thread": False},
                )


            async def main() -> None:
                # Create the engine and database wrapper.
                engine = get_engine_for_chinook_db()
                db = SQLDatabase(engine)

                # Create the toolkit.
                llm = ChatOpenAI(temperature=0)
                toolkit = SQLDatabaseToolkit(db=db, llm=llm)

                # Create the LangChain tool adapter for every tool in the toolkit.
                tools = [LangChainToolAdapter(tool) for tool in toolkit.get_tools()]

                # Create the chat completion client.
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                # Create the assistant agent.
                agent = AssistantAgent(
                    "assistant",
                    model_client=model_client,
                    tools=tools,  # type: ignore
                    model_client_stream=True,
                    system_message="Respond with 'TERMINATE' if the task is completed.",
                )

                # Create termination condition.
                termination = TextMentionTermination("TERMINATE")

                # Create a round-robin group chat to iterate the single agent over multiple steps.
                chat = RoundRobinGroupChat([agent], termination_condition=termination)

                # Run the chat.
                await Console(chat.run_stream(task="Show some tables in the database"))


            if __name__ == "__main__":
                asyncio.run(main())

    """

    def __init__(self, langchain_tool: LangChainTool):
        """Wrap ``langchain_tool``, resolving its callable and its argument model.

        Raises:
            AttributeError: If the tool exposes neither a callable ``func`` nor ``_run``.
            ValueError: If no pydantic ``BaseModel`` subclass can be obtained for the arguments.
        """
        self._langchain_tool: LangChainTool = langchain_tool

        # Extract name and description
        name = self._langchain_tool.name
        description = self._langchain_tool.description or ""

        # Determine the callable method: prefer `func`, fall back to `_run`.
        self._callable: Callable[..., Any]
        if hasattr(self._langchain_tool, "func") and callable(self._langchain_tool.func):  # type: ignore
            assert self._langchain_tool.func is not None  # type: ignore
            self._callable = self._langchain_tool.func  # type: ignore
        elif hasattr(self._langchain_tool, "_run") and callable(self._langchain_tool._run):  # type: ignore
            self._callable = self._langchain_tool._run  # type: ignore
        else:
            raise AttributeError(
                f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method."
            )

        # Determine args_type
        if self._langchain_tool.args_schema:  # pyright: ignore
            args_type = self._langchain_tool.args_schema  # pyright: ignore
        else:
            # Infer args_type from the callable's signature
            sig = inspect.signature(cast(Callable[..., Any], self._callable))  # type: ignore
            fields: Dict[str, Any] = {}
            for param_name, param in sig.parameters.items():
                if param_name == "self" or param.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    continue
                # An unannotated parameter accepts anything; `Parameter.empty` is not a usable type.
                annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
                # Keep the callable's own default so the field is optional exactly when
                # the parameter is; `...` marks the field as required.
                default = ... if param.default is inspect.Parameter.empty else param.default
                fields[param_name] = (annotation, Field(default))
            args_type = create_model(f"{name}Args", **fields)  # type: ignore
            # Note: type ignore is used due to a LangChain typing limitation

        # Ensure args_type is a subclass of BaseModel. The isinstance guard keeps a
        # non-class schema from surfacing as a TypeError out of issubclass.
        if not (isinstance(args_type, type) and issubclass(args_type, BaseModel)):
            raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}")

        # Assume return_type as Any if not specified
        return_type: Type[Any] = object

        super().__init__(args_type, return_type, name, description)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """Call the wrapped tool with ``args`` dumped to keyword arguments.

        Coroutine functions are awaited directly; any other callable is run in a
        worker thread so it does not block the event loop.
        """
        # Prepare arguments
        kwargs = args.model_dump()

        # Determine if the callable is asynchronous
        if inspect.iscoroutinefunction(self._callable):
            return await self._callable(**kwargs)
        # Run in a thread to avoid blocking the event loop
        return await asyncio.to_thread(self._call_sync, kwargs)

    def _call_sync(self, kwargs: Dict[str, Any]) -> Any:
        """Invoke the wrapped callable synchronously with the given keyword arguments."""
        return self._callable(**kwargs)

View File

@@ -0,0 +1,22 @@
from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
from ._factory import mcp_server_tools
from ._session import create_mcp_server_session
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter
from ._streamable_http import StreamableHttpMcpToolAdapter
from ._workbench import McpWorkbench
__all__ = [
"create_mcp_server_session",
"McpSessionActor",
"StdioMcpToolAdapter",
"StdioServerParams",
"SseMcpToolAdapter",
"SseServerParams",
"StreamableHttpMcpToolAdapter",
"StreamableHttpServerParams",
"McpServerParams",
"mcp_server_tools",
"McpWorkbench",
]

View File

@@ -0,0 +1,310 @@
import asyncio
import atexit
import base64
import io
import logging
from typing import Any, Coroutine, Dict, Mapping, TypedDict
from agentdhal_core import Component, ComponentBase, ComponentModel, Image
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
ModelInfo,
SystemMessage,
UserMessage,
)
from PIL import Image as PILImage
from pydantic import BaseModel
from typing_extensions import Self
from mcp import types as mcp_types
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from ._config import McpServerParams
from ._session import create_mcp_server_session
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Union of the coroutine types a session command can produce; each alternative
# is a coroutine that yields one kind of MCP result.
McpResult = (
    Coroutine[Any, Any, mcp_types.ListToolsResult]
    | Coroutine[Any, Any, mcp_types.CallToolResult]
    | Coroutine[Any, Any, mcp_types.ListPromptsResult]
    | Coroutine[Any, Any, mcp_types.ListResourcesResult]
    | Coroutine[Any, Any, mcp_types.ListResourceTemplatesResult]
    | Coroutine[Any, Any, mcp_types.ReadResourceResult]
    | Coroutine[Any, Any, mcp_types.GetPromptResult]
)
# Future whose result is one of the McpResult coroutines.
McpFuture = asyncio.Future[McpResult]
def _parse_sampling_content(
    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent, model_info: ModelInfo
) -> str | Image:
    """Translate one piece of MCP sampling content into its Autogen equivalent.

    Text content is returned as a plain string. Image content is base64-decoded
    and wrapped in an ``Image``, which requires a model whose info reports vision
    support.

    Raises:
        ValueError: If the content is neither text nor image, or if it is an
            image and ``model_info["vision"]`` is falsy.
    """
    if content.type == "text":
        return content.text
    if content.type != "image":
        raise ValueError(f"Unsupported content type: {content.type}")
    if not model_info["vision"]:
        raise ValueError("Sampling model does not support image content.")
    # The payload is base64 text; decode it and let PIL read the raw bytes.
    raw_bytes = base64.b64decode(content.data)
    return Image.from_pil(PILImage.open(io.BytesIO(raw_bytes)))
def _parse_sampling_message(message: mcp_types.SamplingMessage, model_info: ModelInfo) -> LLMMessage:
    """Convert an MCP sampling message into an Autogen LLM message.

    The content is converted first, so content errors are reported before the
    role is inspected.

    Args:
        message: The MCP sampling message to convert.
        model_info: Capabilities of the sampling model, used to validate the content.

    Returns:
        A ``UserMessage`` for the "user" role or an ``AssistantMessage`` for the
        "assistant" role.

    Raises:
        ValueError: If the content cannot be converted, if an assistant message
            carries non-string content, or if the role is not recognized.
    """
    content = _parse_sampling_content(message.content, model_info=model_info)
    if message.role == "user":
        return UserMessage(
            source="user",
            content=[content],
        )
    if message.role == "assistant":
        # Raised explicitly rather than asserted: assert statements are removed
        # when Python runs with -O, which would let non-string content through.
        if not isinstance(content, str):
            raise ValueError("Assistant messages only support string content.")
        return AssistantMessage(
            source="assistant",
            content=content,
        )
    raise ValueError(f"Unrecognized message role: {message.role}")
class McpActorArgs(TypedDict):
    """Arguments that accompany a command sent through ``McpSessionActor.call``."""

    # Tool or prompt name; required for the "call_tool" and "get_prompt" commands.
    name: str | None
    # Extra arguments: passed as the tool arguments for "call_tool"; for
    # "read_resource" its "uri" entry is read, for "get_prompt" its "arguments" entry.
    kargs: Mapping[str, Any]
class McpSessionActorConfig(BaseModel):
    """Serializable configuration for ``McpSessionActor``."""

    # Connection parameters for the MCP server.
    server_params: McpServerParams
    # Optional model client configuration.
    # NOTE(review): McpSessionActor._to_config/_from_config in this file neither
    # write nor read this field - confirm whether that is intended.
    model_client: ComponentModel | Dict[str, Any] | None = None
class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
    """Run one MCP client session in a background task and serve commands from a queue.

    Callers enqueue commands through :meth:`call`; the actor task opens the
    session, takes commands off the queue one at a time and resolves each
    command's future with the (not yet awaited) coroutine returned by the
    corresponding session method. When the actor stops for any reason, every
    command still waiting in the queue is resolved so that no caller is left
    awaiting forever.

    Args:
        server_params: Connection parameters for the MCP server.
        model_client: Optional model client used to answer sampling requests
            from the server. Without it, sampling requests get an error reply.
    """

    component_type = "mcp_session_actor"
    component_config_schema = McpSessionActorConfig
    component_provider_override = "agentdhal_extensions.tools.mcp.McpSessionActor"

    server_params: McpServerParams

    def __init__(self, server_params: McpServerParams, model_client: ChatCompletionClient | None = None) -> None:
        self.server_params: McpServerParams = server_params
        self._model_client = model_client
        self.name = "mcp_session_actor"
        self.description = "MCP session actor"
        self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
        self._actor_task: asyncio.Task[Any] | None = None
        self._shutdown_future: asyncio.Future[Any] | None = None
        self._active = False
        self._initialize_result: mcp_types.InitializeResult | None = None
        # Best-effort shutdown of the session when the interpreter exits.
        atexit.register(self._sync_shutdown)

    @property
    def initialize_result(self) -> mcp_types.InitializeResult | None:
        """Result of the session's ``initialize()`` call, or ``None`` if it has not completed."""
        return self._initialize_result

    async def initialize(self) -> None:
        """Start the background actor task if it is not already running."""
        if not self._active:
            self._active = True
            self._actor_task = asyncio.create_task(self._run_actor())

    async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture:
        """Send a command to the actor and wait until it has been dispatched.

        Args:
            type: One of "list_tools", "list_prompts", "list_resources",
                "list_resource_templates", "shutdown", "call_tool",
                "read_resource" or "get_prompt".
            args: Required for "call_tool", "read_resource" and "get_prompt".

        Returns:
            The value the actor stored in the command's future: for session
            commands, the coroutine returned by the session method, which the
            caller is expected to await.

        Raises:
            RuntimeError: If the actor is not running or stopped before the
                command could be processed.
            ValueError: If the command type is unknown or a required argument is missing.
        """
        if not self._active:
            raise RuntimeError("MCP Actor not running, call initialize() first")
        if self._actor_task and self._actor_task.done():
            raise RuntimeError("MCP actor task crashed", self._actor_task.exception())

        fut: asyncio.Future[McpFuture] = asyncio.Future()
        command: Dict[str, Any]
        if type in {"list_tools", "list_prompts", "list_resources", "list_resource_templates", "shutdown"}:
            command = {"type": type, "future": fut}
        elif type in {"call_tool", "read_resource", "get_prompt"}:
            if args is None:
                raise ValueError(f"args is required for {type}")
            name = args.get("name", None)
            kwargs = args.get("kargs", {})
            if type == "call_tool":
                if name is None:
                    raise ValueError("name is required for call_tool")
                command = {"type": type, "name": name, "args": kwargs, "future": fut}
            elif type == "read_resource":
                uri = kwargs.get("uri", None)
                if uri is None:
                    raise ValueError("uri is required for read_resource")
                command = {"type": type, "uri": uri, "future": fut}
            else:  # get_prompt
                if name is None:
                    raise ValueError("name is required for get_prompt")
                prompt_args = kwargs.get("arguments", None)
                command = {"type": type, "name": name, "args": prompt_args, "future": fut}
        else:
            raise ValueError(f"Unknown command type: {type}")

        await self._command_queue.put(command)
        return await fut

    async def close(self) -> None:
        """Ask the actor to shut down and wait for its task to finish."""
        # Keep a local reference: the actor clears self._actor_task when it
        # exits, which can happen before this coroutine resumes.
        actor_task = self._actor_task
        if not self._active or actor_task is None:
            return
        self._shutdown_future = asyncio.Future()
        await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
        await self._shutdown_future
        await actor_task
        self._active = False

    async def _sampling_callback(
        self,
        context: RequestContext[ClientSession, Any],
        params: mcp_types.CreateMessageRequestParams,
    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
        """Handle sampling requests using the provided model client.

        Failures are reported to the server as ``ErrorData`` instead of being raised.
        """
        if self._model_client is None:
            # Return an error when no model client is available
            return mcp_types.ErrorData(
                code=mcp_types.INVALID_REQUEST,
                message="No model client available for sampling.",
                data=None,
            )

        llm_messages: list[LLMMessage] = []

        try:
            if params.systemPrompt:
                llm_messages.append(SystemMessage(content=params.systemPrompt))

            for mcp_message in params.messages:
                llm_messages.append(_parse_sampling_message(mcp_message, model_info=self._model_client.model_info))

        except Exception as e:
            return mcp_types.ErrorData(
                code=mcp_types.INVALID_PARAMS,
                message="Error processing sampling messages.",
                data=f"{type(e).__name__}: {e}",
            )

        try:
            result = await self._model_client.create(messages=llm_messages)

            content = result.content
            if not isinstance(content, str):
                content = str(content)

            return mcp_types.CreateMessageResult(
                role="assistant",
                content=mcp_types.TextContent(type="text", text=content),
                model=self._model_client.model_info["family"],
                stopReason=result.finish_reason,
            )
        except Exception as e:
            return mcp_types.ErrorData(
                code=mcp_types.INTERNAL_ERROR,
                message="Error sampling from model client.",
                data=f"{type(e).__name__}: {e}",
            )

    @staticmethod
    def _start_command(session: ClientSession, cmd: Dict[str, Any]) -> McpResult:
        """Return the coroutine of the session method that corresponds to ``cmd``."""
        kind = cmd["type"]
        if kind == "call_tool":
            return session.call_tool(name=cmd["name"], arguments=cmd["args"])
        if kind == "read_resource":
            return session.read_resource(uri=cmd["uri"])
        if kind == "get_prompt":
            return session.get_prompt(name=cmd["name"], arguments=cmd["args"])
        if kind == "list_tools":
            return session.list_tools()
        if kind == "list_prompts":
            return session.list_prompts()
        if kind == "list_resources":
            return session.list_resources()
        if kind == "list_resource_templates":
            return session.list_resource_templates()
        raise ValueError(f"Unknown command type: {kind}")

    def _fail_pending_commands(self, cause: BaseException | None) -> None:
        """Resolve every command left in the queue after the actor has stopped."""
        while True:
            try:
                cmd = self._command_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            fut = cmd["future"]
            if fut.done():
                continue
            if cmd["type"] == "shutdown":
                # The actor has stopped, which is what a shutdown asks for.
                fut.set_result("ok")
                continue
            error = RuntimeError("MCP actor stopped before the command could be processed")
            error.__cause__ = cause
            fut.set_exception(error)

    async def _run_actor(self) -> None:
        """Open the session and process queued commands until shutdown or failure."""
        failure: BaseException | None = None
        try:
            async with create_mcp_server_session(
                self.server_params, sampling_callback=self._sampling_callback
            ) as session:
                # Save the initialize result
                self._initialize_result = await session.initialize()
                while True:
                    cmd = await self._command_queue.get()
                    fut = cmd["future"]
                    if cmd["type"] == "shutdown":
                        if not fut.done():
                            fut.set_result("ok")
                        break
                    if fut.done():
                        # The caller gave up (e.g. was cancelled) before the command
                        # was picked up; resolving the future again would raise
                        # InvalidStateError and take the whole actor down.
                        continue
                    try:
                        fut.set_result(self._start_command(session, cmd))
                    except Exception as e:
                        fut.set_exception(e)
        except Exception as e:
            failure = e
            if self._shutdown_future and not self._shutdown_future.done():
                self._shutdown_future.set_exception(e)
            else:
                logger.exception("Exception in MCP actor task")
        finally:
            self._active = False
            self._actor_task = None
            # Without this, callers whose commands were never processed would
            # await their futures forever.
            self._fail_pending_commands(failure)

    def _sync_shutdown(self) -> None:
        """Synchronous shutdown hook registered with ``atexit``."""
        if not self._active or self._actor_task is None:
            return
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No loop available — interpreter is likely shutting down
            return

        if loop.is_closed():
            return

        if loop.is_running():
            loop.create_task(self.close())
        else:
            loop.run_until_complete(self.close())

    def _to_config(self) -> McpSessionActorConfig:
        """
        Convert the actor to its configuration representation.

        Only the server parameters are included; the model client is not serialized.

        Returns:
            McpSessionActorConfig: The configuration of the actor.
        """
        return McpSessionActorConfig(server_params=self.server_params)

    @classmethod
    def _from_config(cls, config: McpSessionActorConfig) -> Self:
        """
        Create an instance of McpSessionActor from its configuration.

        Only ``config.server_params`` is used; the resulting actor has no model client.

        Args:
            config (McpSessionActorConfig): The configuration of the actor.

        Returns:
            McpSessionActor: An instance of McpSessionActor.
        """
        return cls(server_params=config.server_params)

View File

@@ -0,0 +1,190 @@
import asyncio
import builtins
import json
from abc import ABC
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from agentdhal_core.utils import schema_to_pydantic_model
from pydantic import BaseModel
from pydantic.networks import AnyUrl
from mcp import ClientSession, Tool
from mcp.types import AudioContent, ContentBlock, EmbeddedResource, ImageContent, ResourceLink, TextContent
from ._config import McpServerParams
from ._session import create_mcp_server_session
# Type variable for the server-parameter type an adapter is specialized on.
TServerParams = TypeVar("TServerParams", bound=McpServerParams)
class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
    """
    Base adapter class for MCP tools to make them compatible with AutoGen.

    Args:
        server_params (TServerParams): Parameters for the MCP server connection.
        tool (Tool): The MCP tool to wrap.
        session (ClientSession | None): Optional existing session to reuse. When
            omitted, a new session is opened for every call to :meth:`run`.
    """

    component_type = "tool"

    def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
        self._tool = tool
        self._server_params = server_params
        self._session = session

        # Extract name and description
        name = tool.name
        description = tool.description or ""

        # Create the input model from the tool's schema
        input_model = schema_to_pydantic_model(tool.inputSchema)

        # Use Any as return type since MCP tool returns can vary
        return_type: Type[Any] = object

        super().__init__(input_model, return_type, name, description)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """
        Run the MCP tool with the provided arguments.

        Args:
            args (BaseModel): The arguments to pass to the tool.
            cancellation_token (CancellationToken): Token to signal cancellation.

        Returns:
            Any: The result of the tool execution.

        Raises:
            Exception: If the operation is cancelled or the tool execution fails.
        """
        # Convert the input model to a dictionary.
        # Exclude unset values to avoid sending them to the MCP servers, which may
        # cause errors for many servers.
        kwargs = args.model_dump(exclude_unset=True)

        if self._session is not None:
            # If a session is provided, use it directly.
            return await self._run(args=kwargs, cancellation_token=cancellation_token, session=self._session)

        async with create_mcp_server_session(self._server_params) as session:
            await session.initialize()
            return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)

    def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]:
        """
        Normalizes a raw tool output payload into a list of content items.

        - If payload is a string, it's wrapped in [TextContent(text=payload)].
        - If payload is a single ContentBlock item, it's wrapped in a list.
        - If payload is a sequence of ContentBlock items, it's converted to a list and returned.
        - Otherwise, the payload is stringified and wrapped in [TextContent(text=str(payload))].
        """
        content_types = (TextContent, ImageContent, EmbeddedResource, AudioContent, ResourceLink)
        # Strings are checked first: a str is itself a Sequence, and an empty one
        # would otherwise pass the all(...) test vacuously and yield an empty list.
        if isinstance(payload, str):
            return [TextContent(text=payload, type="text")]
        if isinstance(payload, content_types):
            return [payload]
        if isinstance(payload, Sequence) and all(isinstance(item, content_types) for item in payload):
            return list(payload)
        return [TextContent(text=str(payload), type="text")]

    async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
        """Call the tool on ``session`` and return its normalized content list.

        Raises:
            asyncio.CancelledError: If the token is already cancelled, or the
                call is cancelled while in flight.
            Exception: If the tool reports an error; the message is the
                serialized error content.
        """
        if cancellation_token.is_cancelled():
            raise asyncio.CancelledError("Operation cancelled")

        result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
        # Cancelling the token cancels the in-flight tool call.
        cancellation_token.link_future(result_future)
        result = await result_future

        normalized_content_list = self._normalize_payload_to_content_list(result.content)

        if result.isError:
            raise Exception(self.return_value_as_string(normalized_content_list))
        return normalized_content_list

    @classmethod
    async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
        """
        Create an instance of McpToolAdapter from server parameters and tool name.

        Args:
            server_params (TServerParams): Parameters for the MCP server connection.
            tool_name (str): The name of the tool to wrap.

        Returns:
            McpToolAdapter[TServerParams]: An instance of McpToolAdapter.

        Raises:
            ValueError: If the tool with the specified name is not found.
        """
        async with create_mcp_server_session(server_params) as session:
            await session.initialize()

            tools_response = await session.list_tools()
            matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None)

            if matching_tool is None:
                raise ValueError(
                    f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}"
                )

        return cls(server_params=server_params, tool=matching_tool)

    def return_value_as_string(self, value: list[Any]) -> str:
        """Return a JSON string representation of the result.

        Each item is serialized to a dict; ``meta`` entries that are ``None`` are
        dropped (for backward compatibility) and URL values are converted to
        strings. Items of an unrecognized type serialize to an empty dict.
        """

        def drop_none_meta(dumped: dict[str, Any]) -> dict[str, Any]:
            # Remove the 'meta' field if it exists and is None (for backward compatibility)
            if dumped.get("meta") is None:
                dumped.pop("meta", None)
            return dumped

        def serialize_item(item: Any) -> dict[str, Any]:
            if isinstance(item, (TextContent, ImageContent, AudioContent)):
                return drop_none_meta(item.model_dump())
            if isinstance(item, EmbeddedResource):
                resource: dict[str, Any] = {}
                for key, val in item.resource.model_dump().items():
                    # Skip 'meta' field if it's None (for backward compatibility)
                    if key == "meta" and val is None:
                        continue
                    resource[key] = str(val) if isinstance(val, AnyUrl) else val
                dumped_annotations = item.annotations.model_dump() if item.annotations else None
                # Remove 'meta' from annotations if it exists and is None
                if dumped_annotations:
                    drop_none_meta(dumped_annotations)
                return {"type": item.type, "resource": resource, "annotations": dumped_annotations}
            if isinstance(item, ResourceLink):
                dumped = drop_none_meta(item.model_dump())
                # Convert AnyUrl to string for JSON serialization
                if "uri" in dumped and isinstance(dumped["uri"], AnyUrl):
                    dumped["uri"] = str(dumped["uri"])
                return dumped
            return {}

        return json.dumps([serialize_item(item) for item in value])

View File

@@ -0,0 +1,42 @@
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from mcp import StdioServerParameters
class StdioServerParams(StdioServerParameters):
    """Parameters for connecting to an MCP server over STDIO."""

    # Tag that identifies this parameter class within the McpServerParams union.
    type: Literal["StdioServerParams"] = "StdioServerParams"
    # Read timeout in seconds. NOTE(review): presumably applied to reads from
    # the server session - confirm against the session factory.
    read_timeout_seconds: float = 5
class SseServerParams(BaseModel):
    """Parameters for connecting to an MCP server over SSE."""

    # Tag that identifies this parameter class within the McpServerParams union.
    type: Literal["SseServerParams"] = "SseServerParams"
    url: str  # The SSE endpoint URL.
    headers: dict[str, Any] | None = None  # Optional headers to include in requests.
    timeout: float = 5  # HTTP timeout for regular operations.
    sse_read_timeout: float = 60 * 5  # Timeout for SSE read operations.
class StreamableHttpServerParams(BaseModel):
    """Parameters for connecting to an MCP server over Streamable HTTP."""

    # Tag that identifies this parameter class within the McpServerParams union.
    type: Literal["StreamableHttpServerParams"] = "StreamableHttpServerParams"
    url: str  # The endpoint URL.
    headers: dict[str, Any] | None = None  # Optional headers to include in requests.
    timeout: float = 30.0  # HTTP timeout for regular operations in seconds.
    sse_read_timeout: float = 300.0  # Timeout for SSE read operations in seconds.
    # NOTE(review): presumably asks the transport to terminate the server-side
    # session when the connection is closed - confirm against the session factory.
    terminate_on_close: bool = True
# Union of all supported server parameter classes, discriminated by each
# class's literal `type` field.
McpServerParams = Annotated[
    StdioServerParams | SseServerParams | StreamableHttpServerParams, Field(discriminator="type")
]

View File

@@ -0,0 +1,214 @@
from mcp import ClientSession
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
from ._session import create_mcp_server_session
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter
from ._streamable_http import StreamableHttpMcpToolAdapter
async def mcp_server_tools(
    server_params: McpServerParams,
    session: ClientSession | None = None,
) -> list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
    """Build one tool adapter for every tool exposed by an MCP server.

    The server's tool list is fetched once -- through ``session`` when one is
    supplied, otherwise through a temporary session that is closed again as soon
    as the listing has been read -- and each advertised tool is wrapped in the
    adapter class matching the type of ``server_params``. The returned adapters
    can be assigned directly to an AutoGen agent's tools list.

    .. warning::

        Only connect to trusted MCP servers, especially when using
        `StdioServerParams` as it executes commands in the local environment.

    .. note::

        This function requires the `mcp` extra:

        .. code-block:: bash

            pip install -U "agentdhal-ext[mcp]"

    Args:
        server_params (McpServerParams): Connection parameters for the MCP server.
            Use `StdioServerParams` for command-line tools, or `SseServerParams` /
            `StreamableHttpServerParams` for HTTP-based services.
        session (ClientSession | None): Optional existing session. When given, it is
            used to list the tools and is handed to every adapter so that all of them
            share the same connection. When omitted, the adapters receive no session.

    Returns:
        list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
        One adapter per tool reported by the server.

    Raises:
        ValueError: If ``server_params`` is not one of the supported parameter types.

    Examples:

        **Local file system MCP service over standard I/O example:**

        Install the filesystem server package from npm (requires Node.js 16+ and npm).

        .. code-block:: bash

            npm install -g @modelcontextprotocol/server-filesystem

        Create an agent that can use all tools from the local filesystem MCP server.

        .. code-block:: python

            import asyncio
            from pathlib import Path
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_core import CancellationToken


            async def main() -> None:
                # Setup server params for local filesystem access
                desktop = str(Path.home() / "Desktop")
                server_params = StdioServerParams(
                    command="npx.cmd", args=["-y", "@modelcontextprotocol/server-filesystem", desktop]
                )

                # Get all available tools from the server
                tools = await mcp_server_tools(server_params)

                # Create an agent that can use all the tools
                agent = AssistantAgent(
                    name="file_manager",
                    model_client=OpenAIChatCompletionClient(model="gpt-4"),
                    tools=tools,  # type: ignore
                )

                # The agent can now use any of the filesystem tools
                await agent.run(task="Create a file called test.txt with some content", cancellation_token=CancellationToken())


            if __name__ == "__main__":
                asyncio.run(main())

        **Local fetch MCP service over standard I/O example:**

        Install the `mcp-server-fetch` package.

        .. code-block:: bash

            pip install mcp-server-fetch

        Create an agent that can use the `fetch` tool from the local MCP server.

        .. code-block:: python

            import asyncio

            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import StdioServerParams, mcp_server_tools


            async def main() -> None:
                # Get the fetch tool from mcp-server-fetch.
                fetch_mcp_server = StdioServerParams(command="uvx", args=["mcp-server-fetch"])
                tools = await mcp_server_tools(fetch_mcp_server)

                # Create an agent that can use the fetch tool.
                model_client = OpenAIChatCompletionClient(model="gpt-4o")
                agent = AssistantAgent(name="fetcher", model_client=model_client, tools=tools, reflect_on_tool_use=True)  # type: ignore

                # Let the agent fetch the content of a URL and summarize it.
                result = await agent.run(task="Summarize the content of https://en.wikipedia.org/wiki/Seattle")
                print(result.messages[-1])


            asyncio.run(main())

        **Sharing an MCP client session across multiple tools:**

        You can create a single MCP client session and share it across multiple tools.
        This is sometimes required when the server maintains a session state
        (e.g., a browser state) that should be reused for multiple requests.

        The following example shows how to create a single MCP client session
        to a local `Playwright <https://github.com/microsoft/playwright-mcp>`_
        server and share it across multiple tools.

        .. code-block:: python

            import asyncio

            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_agentchat.ui import Console
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False)  # type: ignore
                params = StdioServerParams(
                    command="npx",
                    args=["@playwright/mcp@latest"],
                    read_timeout_seconds=60,
                )
                async with create_mcp_server_session(params) as session:
                    await session.initialize()
                    tools = await mcp_server_tools(server_params=params, session=session)
                    print(f"Tools: {[tool.name for tool in tools]}")

                    agent = AssistantAgent(
                        name="Assistant",
                        model_client=model_client,
                        tools=tools,  # type: ignore
                    )

                    termination = TextMentionTermination("TERMINATE")
                    team = RoundRobinGroupChat([agent], termination_condition=termination)
                    await Console(
                        team.run_stream(
                            task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page."
                        )
                    )


            asyncio.run(main())

        **Remote MCP service over SSE example:**

        .. code-block:: python

            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import SseServerParams, mcp_server_tools


            async def main() -> None:
                # Setup server params for remote service
                server_params = SseServerParams(url="https://api.example.com/mcp", headers={"Authorization": "Bearer token"})

                # Get all available tools
                tools = await mcp_server_tools(server_params)

                # Create an agent with all tools
                agent = AssistantAgent(name="tool_user", model_client=OpenAIChatCompletionClient(model="gpt-4"), tools=tools)  # type: ignore

    For more examples and detailed usage, see the samples directory in the package repository.
    """
    # Fetch the tool listing, borrowing the caller's session when there is one.
    if session is not None:
        listing = await session.list_tools()
    else:
        async with create_mcp_server_session(server_params) as temp_session:
            await temp_session.initialize()
            listing = await temp_session.list_tools()

    # Wrap every tool in the adapter that belongs to the transport in use.
    if isinstance(server_params, StdioServerParams):
        return [
            StdioMcpToolAdapter(server_params=server_params, tool=mcp_tool, session=session)
            for mcp_tool in listing.tools
        ]
    if isinstance(server_params, SseServerParams):
        return [
            SseMcpToolAdapter(server_params=server_params, tool=mcp_tool, session=session)
            for mcp_tool in listing.tools
        ]
    if isinstance(server_params, StreamableHttpServerParams):
        return [
            StreamableHttpMcpToolAdapter(server_params=server_params, tool=mcp_tool, session=session)
            for mcp_tool in listing.tools
        ]
    raise ValueError(f"Unsupported server params type: {type(server_params)}")

View File

@@ -0,0 +1,55 @@
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
from mcp import ClientSession
from mcp.client.session import SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
@asynccontextmanager
async def create_mcp_server_session(
    server_params: McpServerParams, sampling_callback: SamplingFnT | None = None
) -> AsyncGenerator[ClientSession, None]:
    """Open a transport to an MCP server and yield a client session bound to it.

    The transport is chosen from the type of ``server_params`` (stdio, SSE or
    Streamable HTTP). Both the transport and the session are closed when the
    ``async with`` block exits. ``initialize()`` is not called here; the caller
    is expected to do that on the yielded session.

    Args:
        server_params (McpServerParams): Connection parameters for the server.
        sampling_callback (SamplingFnT | None): Optional callback forwarded to
            :class:`ClientSession` as its ``sampling_callback``.

    Yields:
        ClientSession: An entered session wired to the server's read/write streams.

    Raises:
        ValueError: If ``server_params`` is not one of the supported parameter types.
    """
    if isinstance(server_params, StdioServerParams):
        async with stdio_client(server_params) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.read_timeout_seconds),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, SseServerParams):
        # Every field except the discriminator is passed straight to sse_client.
        async with sse_client(**server_params.model_dump(exclude={"type"})) as (read, write):
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    elif isinstance(server_params, StreamableHttpServerParams):
        # Convert float seconds to timedelta for the streamablehttp_client
        params_dict = server_params.model_dump(exclude={"type"})
        params_dict["timeout"] = timedelta(seconds=server_params.timeout)
        params_dict["sse_read_timeout"] = timedelta(seconds=server_params.sse_read_timeout)
        async with streamablehttp_client(**params_dict) as (
            read,
            write,
            session_id_callback,  # type: ignore[assignment, unused-variable]
        ):
            # TODO: Handle session_id_callback if needed
            async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=timedelta(seconds=server_params.sse_read_timeout),
                sampling_callback=sampling_callback,
            ) as session:
                yield session
    else:
        # Without this branch the generator would finish without yielding and
        # asynccontextmanager would surface an opaque "generator didn't yield" error.
        raise ValueError(f"Unsupported server params type: {type(server_params)}")

View File

@@ -0,0 +1,116 @@
from agentdhal_core import Component
from pydantic import BaseModel
from typing_extensions import Self
from mcp import ClientSession, Tool
from ._base import McpToolAdapter
from ._config import SseServerParams
class SseMcpToolAdapterConfig(BaseModel):
    """Declarative configuration from which an SSE MCP tool adapter is rebuilt."""

    # Parameters of the SSE server connection used by the adapter.
    server_params: SseServerParams
    # The MCP tool that the adapter wraps.
    tool: Tool
class SseMcpToolAdapter(
    McpToolAdapter[SseServerParams],
    Component[SseMcpToolAdapterConfig],
):
    """Wrap an MCP tool served over Server-Sent Events (SSE) so AutoGen agents can call it.

    Use this adapter for MCP-compatible tools that communicate over HTTP with SSE:
    typically remote MCP services, cloud-based tools, and web APIs that implement
    the Model Context Protocol (MCP).

    .. note::

        This class requires the `mcp` extra:

        .. code-block:: bash

            pip install -U "agentdhal-ext[mcp]"

    Args:
        server_params (SseServerParams): Parameters for the MCP server connection,
            including URL, headers, and timeouts.
        tool (Tool): The MCP tool to wrap.
        session (ClientSession, optional): The MCP client session to use. If not provided,
            a new session will be created. Supplying one is useful for testing or when
            you want to manage the session lifecycle yourself.

    Examples:
        Use a remote translation service that implements MCP over SSE to create tools
        that allow AutoGen agents to perform translations:

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import SseMcpToolAdapter, SseServerParams
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.ui import Console
            from agentdhal_core import CancellationToken


            async def main() -> None:
                # Create server params for the remote MCP service
                server_params = SseServerParams(
                    url="https://api.example.com/mcp",
                    headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"},
                    timeout=30,  # Connection timeout in seconds
                )

                # Get the translation tool from the server
                adapter = await SseMcpToolAdapter.from_server_params(server_params, "translate")

                # Create an agent that can use the translation tool
                model_client = OpenAIChatCompletionClient(model="gpt-4")
                agent = AssistantAgent(
                    name="translator",
                    model_client=model_client,
                    tools=[adapter],
                    system_message="You are a helpful translation assistant.",
                )

                # Let the agent translate some text
                await Console(
                    agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken())
                )


            if __name__ == "__main__":
                asyncio.run(main())

    """

    component_config_schema = SseMcpToolAdapterConfig
    component_provider_override = "agentdhal_extensions.tools.mcp.SseMcpToolAdapter"

    def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None:
        super().__init__(server_params=server_params, tool=tool, session=session)

    def _to_config(self) -> SseMcpToolAdapterConfig:
        """Capture this adapter's server parameters and tool as a config object.

        Returns:
            SseMcpToolAdapterConfig: The configuration of the adapter. The session
            is not included.
        """
        snapshot = SseMcpToolAdapterConfig(tool=self._tool, server_params=self._server_params)
        return snapshot

    @classmethod
    def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
        """Rebuild an adapter from a configuration produced by :meth:`_to_config`.

        Args:
            config (SseMcpToolAdapterConfig): The configuration of the adapter.

        Returns:
            SseMcpToolAdapter: A new adapter constructed without a session.
        """
        adapter = cls(tool=config.tool, server_params=config.server_params)
        return adapter

View File

@@ -0,0 +1,74 @@
from agentdhal_core import Component
from pydantic import BaseModel
from typing_extensions import Self
from mcp import ClientSession, Tool
from ._base import McpToolAdapter
from ._config import StdioServerParams
class StdioMcpToolAdapterConfig(BaseModel):
    """Declarative configuration from which a STDIO MCP tool adapter is rebuilt."""

    # Parameters of the STDIO server connection used by the adapter.
    server_params: StdioServerParams
    # The MCP tool that the adapter wraps.
    tool: Tool
class StdioMcpToolAdapter(
    McpToolAdapter[StdioServerParams],
    Component[StdioMcpToolAdapterConfig],
):
    """Wrap an MCP tool running over STDIO so AutoGen agents can call it.

    Use this adapter for MCP-compatible tools that communicate over standard
    input/output: typically command-line tools and local services that implement
    the Model Context Protocol (MCP).

    .. note::

        This class requires the `mcp` extra:

        .. code-block:: bash

            pip install -U "agentdhal-ext[mcp]"

    Args:
        server_params (StdioServerParams): Parameters for the MCP server connection,
            including command to run and its arguments
        tool (Tool): The MCP tool to wrap
        session (ClientSession, optional): The MCP client session to use. If not provided,
            a new session will be created. Supplying one is useful for testing or when
            you want to manage the session lifecycle yourself.

    See :func:`~agentdhal_extensions.tools.mcp.mcp_server_tools` for examples.
    """

    component_config_schema = StdioMcpToolAdapterConfig
    component_provider_override = "agentdhal_extensions.tools.mcp.StdioMcpToolAdapter"

    def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None:
        super().__init__(server_params=server_params, tool=tool, session=session)

    def _to_config(self) -> StdioMcpToolAdapterConfig:
        """Capture this adapter's server parameters and tool as a config object.

        Returns:
            StdioMcpToolAdapterConfig: The configuration of the adapter. The session
            is not included.
        """
        snapshot = StdioMcpToolAdapterConfig(tool=self._tool, server_params=self._server_params)
        return snapshot

    @classmethod
    def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self:
        """Rebuild an adapter from a configuration produced by :meth:`_to_config`.

        Args:
            config (StdioMcpToolAdapterConfig): The configuration of the adapter.

        Returns:
            StdioMcpToolAdapter: A new adapter constructed without a session.
        """
        adapter = cls(tool=config.tool, server_params=config.server_params)
        return adapter

View File

@@ -0,0 +1,121 @@
from agentdhal_core import Component
from pydantic import BaseModel
from typing_extensions import Self
from mcp import ClientSession, Tool
from ._base import McpToolAdapter
from ._config import StreamableHttpServerParams
class StreamableHttpMcpToolAdapterConfig(BaseModel):
    """Declarative configuration from which a Streamable HTTP MCP tool adapter is rebuilt."""

    # Parameters of the Streamable HTTP server connection used by the adapter.
    server_params: StreamableHttpServerParams
    # The MCP tool that the adapter wraps.
    tool: Tool
class StreamableHttpMcpToolAdapter(
    McpToolAdapter[StreamableHttpServerParams],
    Component[StreamableHttpMcpToolAdapterConfig],
):
    """Wrap an MCP tool served over Streamable HTTP so AutoGen agents can call it.

    Use this adapter for MCP-compatible tools that communicate over Streamable HTTP:
    typically remote MCP services, cloud-based tools, and web APIs that implement
    the Model Context Protocol (MCP).

    .. note::

        This class requires the `mcp` extra:

        .. code-block:: bash

            pip install -U "agentdhal-ext[mcp]"

    Args:
        server_params (StreamableHttpServerParams): Parameters for the MCP server connection,
            including URL, headers, and timeouts.
        tool (Tool): The MCP tool to wrap.
        session (ClientSession, optional): The MCP client session to use. If not provided,
            a new session will be created. Supplying one is useful for testing or when
            you want to manage the session lifecycle yourself.

    Examples:
        Use a remote translation service that implements MCP over Streamable HTTP to
        create tools that allow AutoGen agents to perform translations:

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.tools.mcp import StreamableHttpMcpToolAdapter, StreamableHttpServerParams
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.ui import Console
            from agentdhal_core import CancellationToken


            async def main() -> None:
                # Create server params for the remote MCP service
                server_params = StreamableHttpServerParams(
                    url="https://api.example.com/mcp",
                    headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"},
                    timeout=30.0,  # HTTP timeout in seconds
                    sse_read_timeout=300.0,  # SSE read timeout in seconds (5 minutes)
                    terminate_on_close=True,
                )

                # Get the translation tool from the server
                adapter = await StreamableHttpMcpToolAdapter.from_server_params(server_params, "translate")

                # Create an agent that can use the translation tool
                model_client = OpenAIChatCompletionClient(model="gpt-4")
                agent = AssistantAgent(
                    name="translator",
                    model_client=model_client,
                    tools=[adapter],
                    system_message="You are a helpful translation assistant.",
                )

                # Let the agent translate some text
                await Console(
                    agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken())
                )


            if __name__ == "__main__":
                asyncio.run(main())

    """

    component_config_schema = StreamableHttpMcpToolAdapterConfig
    component_provider_override = "agentdhal_extensions.tools.mcp.StreamableHttpMcpToolAdapter"

    def __init__(
        self, server_params: StreamableHttpServerParams, tool: Tool, session: ClientSession | None = None
    ) -> None:
        super().__init__(server_params=server_params, tool=tool, session=session)

    def _to_config(self) -> StreamableHttpMcpToolAdapterConfig:
        """Capture this adapter's server parameters and tool as a config object.

        Returns:
            StreamableHttpMcpToolAdapterConfig: The configuration of the adapter.
            The session is not included.
        """
        snapshot = StreamableHttpMcpToolAdapterConfig(tool=self._tool, server_params=self._server_params)
        return snapshot

    @classmethod
    def _from_config(cls, config: StreamableHttpMcpToolAdapterConfig) -> Self:
        """Rebuild an adapter from a configuration produced by :meth:`_to_config`.

        Args:
            config (StreamableHttpMcpToolAdapterConfig): The configuration of the adapter.

        Returns:
            StreamableHttpMcpToolAdapter: A new adapter constructed without a session.
        """
        adapter = cls(tool=config.tool, server_params=config.server_params)
        return adapter

View File

@@ -0,0 +1,518 @@
import asyncio
import builtins
import warnings
from typing import Any, Dict, List, Literal, Mapping, Optional
from agentdhal_core import CancellationToken, Component, ComponentModel, Image, trace_tool_span
from agentdhal_core.models import ChatCompletionClient
from agentdhal_core.tools import (
ImageResultContent,
ParametersSchema,
TextResultContent,
ToolOverride,
ToolResult,
ToolSchema,
Workbench,
)
from pydantic import BaseModel, Field
from typing_extensions import Self
from mcp.types import (
CallToolResult,
EmbeddedResource,
GetPromptResult,
ImageContent,
ListPromptsResult,
ListResourcesResult,
ListResourceTemplatesResult,
ListToolsResult,
ReadResourceResult,
TextContent,
)
from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams, StreamableHttpServerParams
class McpWorkbenchConfig(BaseModel):
    """Declarative configuration from which an MCP workbench is rebuilt."""

    # Parameters used to connect to the MCP server.
    server_params: McpServerParams
    # Per-tool name/description overrides, keyed by the tool's original name.
    tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
    # Serialized chat completion client, or None when no client is configured.
    model_client: ComponentModel | Dict[str, Any] | None = None
class McpWorkbenchState(BaseModel):
    """Serialized state of an MCP workbench; it carries only a type tag."""

    # Constant discriminator identifying this state payload.
    type: Literal["McpWorkBenchState"] = "McpWorkBenchState"
class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
"""A workbench that wraps an MCP server and provides an interface
to list and call tools provided by the server.
.. warning::
Only connect to trusted MCP servers, especially when using
`StdioServerParams` as it executes commands in the local environment.
This workbench should be used as a context manager to ensure proper
initialization and cleanup of the underlying MCP session.
.. list-table:: MCP Support
:header-rows: 1
:widths: 30 70
* - MCP Capability
- Supported Features
* - Tools
- list_tools, call_tool
* - Resources
- list_resources, read_resource
* - ResourceTemplates
- list_resource_templates, read_resource_template
* - Prompts
- list_prompts, get_prompt
* - Sampling
- Optional support via model_client
* - Roots
- not supported
* - Elicitation
- not supported
Args:
server_params (McpServerParams): The parameters to connect to the MCP server.
This can be a :class:`StdioServerParams`, :class:`SseServerParams`, or :class:`StreamableHttpServerParams`.
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
names to override configurations for name and/or description. This allows
customizing how server tools appear to consumers while maintaining the underlying
tool functionality.
model_client: Optional chat completion client to handle sampling requests
from MCP servers that support the sampling capability. This allows MCP
servers to request text generation from a language model during tool
execution. If not provided, sampling requests will return an error.
Raises:
ValueError: If there are conflicts in tool override names.
Examples:
Here is a simple example of how to use the workbench with a `mcp-server-fetch` server:
.. code-block:: python
import asyncio
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
async def main() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)
# You can also use `start()` and `stop()` to manage the session.
async with McpWorkbench(server_params=params) as workbench:
tools = await workbench.list_tools()
print(tools)
result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"})
print(result)
asyncio.run(main())
Example of using tool overrides:
.. code-block:: python
import asyncio
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
from agentdhal_core.tools import ToolOverride
async def main() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)
# Override the fetch tool's name and description
overrides = {
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling")
}
async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench:
tools = await workbench.list_tools()
# The tool will now appear as "web_fetch" with the new description
print(tools)
# Call the overridden tool
result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"})
print(result)
asyncio.run(main())
Example of using the workbench with the `GitHub MCP Server <https://github.com/github/github-mcp-server>`_:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
server_params = StdioServerParams(
command="docker",
args=[
"run",
"-i",
"--rm",
"-e",
"GITHUB_PERSONAL_ACCESS_TOKEN",
"ghcr.io/github/github-mcp-server",
],
env={
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
},
)
async with McpWorkbench(server_params) as mcp:
agent = AssistantAgent(
"github_assistant",
model_client=model_client,
workbench=mcp,
reflect_on_tool_use=True,
model_client_stream=True,
)
await Console(agent.run_stream(task="Is there a repository named Autogen"))
asyncio.run(main())
Example of using the workbench with the `Playwright MCP Server <https://github.com/microsoft/playwright-mcp>`_:
.. code-block:: python
# First run `npm install -g @playwright/mcp@latest` to install the MCP server.
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.teams import RoundRobinGroupChat
from agentdhal_agentchat.conditions import TextMessageTermination
from agentdhal_agentchat.ui import Console
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
from agentdhal_extensions.tools.mcp import McpWorkbench, StdioServerParams
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
server_params = StdioServerParams(
command="npx",
args=[
"@playwright/mcp@latest",
"--headless",
],
)
async with McpWorkbench(server_params) as mcp:
agent = AssistantAgent(
"web_browsing_assistant",
model_client=model_client,
workbench=mcp,
model_client_stream=True,
)
team = RoundRobinGroupChat(
[agent],
termination_condition=TextMessageTermination(source="web_browsing_assistant"),
)
await Console(team.run_stream(task="Find out how many contributors for the microsoft/autogen repository"))
asyncio.run(main())
"""
component_provider_override = "agentdhal_extensions.tools.mcp.McpWorkbench"
component_config_schema = McpWorkbenchConfig
def __init__(
self,
server_params: McpServerParams,
tool_overrides: Optional[Dict[str, ToolOverride]] = None,
model_client: ChatCompletionClient | None = None,
) -> None:
self._server_params = server_params
self._tool_overrides = tool_overrides or {}
self._model_client = model_client
# Build reverse mapping from override names to original names for call_tool
self._override_name_to_original: Dict[str, str] = {}
for original_name, override in self._tool_overrides.items():
override_name = override.name
if override_name and override_name != original_name:
# Check for conflicts with other override names
if override_name in self._override_name_to_original:
existing_original = self._override_name_to_original[override_name]
raise ValueError(
f"Tool override name '{override_name}' is used by multiple tools: "
f"'{existing_original}' and '{original_name}'. Override names must be unique."
)
self._override_name_to_original[override_name] = original_name
# self._session: ClientSession | None = None
self._actor: McpSessionActor | None = None
self._actor_loop: asyncio.AbstractEventLoop | None = None
self._read = None
self._write = None
@property
def server_params(self) -> McpServerParams:
return self._server_params
async def list_tools(self) -> List[ToolSchema]:
    """Return the schemas of the tools offered by the MCP server.

    The workbench is started on demand when no actor exists yet. Any configured
    name or description override is applied to the matching tool.

    Returns:
        List[ToolSchema]: One schema per tool in the server's listing. Parameters
        are taken from the tool's ``inputSchema``; missing ``properties``,
        ``required`` and ``additionalProperties`` default to ``{}``, ``[]`` and
        ``False`` respectively.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        # Start lazily instead of raising: a workbench that was just
        # deserialized has not had its actor initialized yet.
        await self.start()
    if self._actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    result_future = await self._actor.call("list_tools", None)
    list_tool_result = await result_future
    assert isinstance(
        list_tool_result, ListToolsResult
    ), f"list_tools must return a ListToolsResult, instead of : {str(type(list_tool_result))}"

    schema: List[ToolSchema] = []
    for tool in list_tool_result.tools:
        original_name = tool.name
        name = original_name
        description = tool.description or ""

        # Apply overrides if they exist for this tool
        if original_name in self._tool_overrides:
            override = self._tool_overrides[original_name]
            if override.name is not None:
                name = override.name
            if override.description is not None:
                description = override.description

        parameters = ParametersSchema(
            type="object",
            properties=tool.inputSchema.get("properties", {}),
            required=tool.inputSchema.get("required", []),
            additionalProperties=tool.inputSchema.get("additionalProperties", False),
        )
        schema.append(
            ToolSchema(
                name=name,
                description=description,
                parameters=parameters,
            )
        )
    return schema
async def call_tool(
    self,
    name: str,
    arguments: Mapping[str, Any] | None = None,
    cancellation_token: CancellationToken | None = None,
    call_id: str | None = None,
) -> ToolResult:
    """Invoke a tool on the MCP server and convert its reply into a ``ToolResult``.

    The workbench is started on demand when no actor exists yet. Any exception
    raised while calling the tool or converting its content is caught and
    reported as an error result rather than propagated.

    Args:
        name: Tool name as seen by the caller; an override name is mapped back
            to the server's original name before the call.
        arguments: Arguments for the tool; an empty mapping is sent when omitted.
        cancellation_token: Token linked to the pending call; a fresh one is
            created when omitted.
        call_id: Identifier passed to the tracing span.

    Returns:
        ToolResult: Carries the requested ``name``, the converted content parts,
        and the error flag.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        # Start lazily instead of raising: a workbench that was just
        # deserialized has not had its actor initialized yet.
        await self.start()
    if self._actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    token = cancellation_token if cancellation_token else CancellationToken()
    call_args = arguments if arguments else {}

    # Translate an override name back to the name the server knows.
    server_tool_name = self._override_name_to_original.get(name, name)

    parts: List[TextResultContent | ImageResultContent]
    # The span is tagged with the requested name, not the server-side one.
    with trace_tool_span(tool_name=name, tool_call_id=call_id):
        try:
            pending = await self._actor.call("call_tool", {"name": server_tool_name, "kargs": call_args})
            token.link_future(pending)
            outcome = await pending
            assert isinstance(
                outcome, CallToolResult
            ), f"call_tool must return a CallToolResult, instead of : {str(type(outcome))}"
            parts = []
            failed = outcome.isError
            for item in outcome.content:
                if isinstance(item, TextContent):
                    parts.append(TextResultContent(content=item.text))
                elif isinstance(item, ImageContent):
                    parts.append(ImageResultContent(content=Image.from_base64(item.data)))
                elif isinstance(item, EmbeddedResource):
                    # TODO: how to handle embedded resources?
                    # For now we just use text representation.
                    parts.append(TextResultContent(content=item.model_dump_json()))
                else:
                    raise ValueError(f"Unknown content type from server: {type(item)}")
        except Exception as exc:
            # Report the failure through the result instead of raising.
            failed = True
            parts = [TextResultContent(content=self._format_errors(exc))]
    return ToolResult(name=name, result=parts, is_error=failed)
@property
def initialize_result(self) -> Any:
if self._actor:
return self._actor.initialize_result
return None
async def list_prompts(self) -> ListPromptsResult:
    """List available prompts from the MCP server.

    The workbench is started on demand when no actor exists yet.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        await self.start()
    actor = self._actor
    if actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    pending = await actor.call("list_prompts", None)
    prompts = await pending
    assert isinstance(
        prompts, ListPromptsResult
    ), f"list_prompts must return a ListPromptsResult, instead of: {type(prompts)!s}"
    return prompts
async def list_resources(self) -> ListResourcesResult:
    """List available resources from the MCP server.

    The workbench is started on demand when no actor exists yet.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        await self.start()
    actor = self._actor
    if actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    pending = await actor.call("list_resources", None)
    resources = await pending
    assert isinstance(
        resources, ListResourcesResult
    ), f"list_resources must return a ListResourcesResult, instead of: {type(resources)!s}"
    return resources
async def list_resource_templates(self) -> ListResourceTemplatesResult:
    """List available resource templates from the MCP server.

    The workbench is started on demand when no actor exists yet.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        await self.start()
    actor = self._actor
    if actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    pending = await actor.call("list_resource_templates", None)
    templates = await pending
    assert isinstance(
        templates, ListResourceTemplatesResult
    ), f"list_resource_templates must return a ListResourceTemplatesResult, instead of: {type(templates)!s}"
    return templates
async def read_resource(self, uri: str) -> ReadResourceResult:
    """Read a resource from the MCP server.

    The workbench is started on demand when no actor exists yet.

    Args:
        uri: URI of the resource to read.

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        await self.start()
    actor = self._actor
    if actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    request = {"name": None, "kargs": {"uri": uri}}
    pending = await actor.call("read_resource", request)
    resource = await pending
    assert isinstance(
        resource, ReadResourceResult
    ), f"read_resource must return a ReadResourceResult, instead of: {type(resource)!s}"
    return resource
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> GetPromptResult:
    """Get a prompt from the MCP server.

    The workbench is started on demand when no actor exists yet.

    Args:
        name: Name of the prompt to fetch.
        arguments: Optional prompt arguments, forwarded unchanged (``None`` included).

    Raises:
        RuntimeError: If there is still no actor after attempting to start.
    """
    if not self._actor:
        await self.start()
    actor = self._actor
    if actor is None:
        raise RuntimeError("Actor is not initialized. Please check the server connection.")

    request = {"name": name, "kargs": {"arguments": arguments}}
    pending = await actor.call("get_prompt", request)
    prompt = await pending
    assert isinstance(
        prompt, GetPromptResult
    ), f"get_prompt must return a GetPromptResult, instead of: {type(prompt)!s}"
    return prompt
def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""
error_message = ""
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
# ExceptionGroup is available in Python 3.11+.
# TODO: how to make this compatible with Python 3.10?
for sub_exception in error.exceptions: # type: ignore
error_message += self._format_errors(sub_exception) # type: ignore
else:
error_message += f"{str(error)}\n"
return error_message
async def start(self) -> None:
    """Start the workbench by creating and initializing the MCP session actor.

    When an actor already exists, a ``UserWarning`` is emitted and nothing else
    happens. Otherwise a ``McpSessionActor`` is built from the configured server
    parameters and model client, initialized, and the currently running event
    loop is remembered in ``self._actor_loop`` (``__del__`` reads it later).

    Raises:
        ValueError: If ``self._server_params`` is not a ``StdioServerParams``,
            ``SseServerParams`` or ``StreamableHttpServerParams`` instance.
    """
    if self._actor:
        warnings.warn(
            "McpWorkbench is already started. No need to start again.",
            UserWarning,
            stacklevel=2,
        )
        return  # Already initialized, no need to start again

    if not isinstance(self._server_params, (StdioServerParams, SseServerParams, StreamableHttpServerParams)):
        raise ValueError(f"Unsupported server params type: {type(self._server_params)}")

    actor = McpSessionActor(self._server_params, model_client=self._model_client)
    self._actor = actor
    await actor.initialize()
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the documented way to obtain it here.
    self._actor_loop = asyncio.get_running_loop()
async def stop(self) -> None:
    """Close the session actor and mark the workbench as stopped.

    Raises:
        RuntimeError: If the workbench has no actor, i.e. it was never started
            or has already been stopped.
    """
    if not self._actor:
        raise RuntimeError("McpWorkbench is not started. Call start() first.")

    # Close first, then drop the reference so a later start() builds a new actor.
    await self._actor.close()
    self._actor = None
async def reset(self) -> None:
    """Reset the workbench. This implementation has nothing to reset and does nothing."""
    return None
async def save_state(self) -> Mapping[str, Any]:
    """Return the workbench state as a plain mapping.

    The state is always a freshly constructed, default ``McpWorkbenchState``
    dumped via ``model_dump()``; nothing is read from this instance.
    """
    snapshot = McpWorkbenchState()
    return snapshot.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Load a previously saved state. This implementation ignores ``state`` and does nothing."""
    return None
def _to_config(self) -> McpWorkbenchConfig:
    """Build the declarative configuration describing this workbench.

    Returns:
        A ``McpWorkbenchConfig`` holding the server parameters, the tool
        overrides and, when a model client is set, its dumped component
        configuration (``None`` otherwise).
    """
    client = self._model_client
    dumped_client = client.dump_component() if client is not None else None
    return McpWorkbenchConfig(
        server_params=self._server_params,
        tool_overrides=self._tool_overrides,
        model_client=dumped_client,
    )
@classmethod
def _from_config(cls, config: McpWorkbenchConfig) -> Self:
    """Create a workbench from its declarative configuration.

    Args:
        config: Configuration carrying the server parameters, the tool
            overrides and an optional model-client component configuration.

    Returns:
        A new instance of ``cls``. The model client is loaded through
        ``ChatCompletionClient.load_component`` when the configuration
        provides one, and is ``None`` otherwise.
    """
    client_config = config.model_client
    model_client = ChatCompletionClient.load_component(client_config) if client_config is not None else None
    return cls(
        server_params=config.server_params,
        tool_overrides=config.tool_overrides,
        model_client=model_client,
    )
def __del__(self) -> None:
    """Best-effort shutdown of the actor when the workbench is garbage collected.

    When an actor and its event loop are both known and the loop is still
    running, ``stop()`` is scheduled on that loop. When the loop is closed or
    not running, a ``RuntimeWarning`` is emitted instead.
    """
    # getattr keeps this safe even if __init__ failed before the attributes were set.
    actor = getattr(self, "_actor", None)
    loop = getattr(self, "_actor_loop", None)
    if not (actor and loop):
        return

    if not loop.is_running() or loop.is_closed():
        warnings.warn(
            "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    # The task has to be created on the actor's loop, so hand it over via
    # call_soon_threadsafe instead of creating it here.
    loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))

View File

@@ -0,0 +1,6 @@
from ._kernel_function_from_tool import KernelFunctionFromTool, KernelFunctionFromToolSchema
__all__ = [
"KernelFunctionFromTool",
"KernelFunctionFromToolSchema",
]

View File

@@ -0,0 +1,94 @@
from typing import Any, TypeVar
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool, ToolSchema
from pydantic import BaseModel
from semantic_kernel.functions import KernelFunctionFromMethod, KernelFunctionFromPrompt, kernel_function
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.prompt_template.input_variable import InputVariable
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
# Type variables used to parameterize ``BaseTool`` in this module; both are
# bounded to pydantic ``BaseModel``.
InputT = TypeVar("InputT", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)
class KernelFunctionFromTool(KernelFunctionFromMethod):
    """Semantic Kernel function that delegates its invocation to a ``BaseTool``.

    The function's name and description come from the tool, its parameter
    metadata is derived from the tool's JSON schema together with the pydantic
    argument model, and invoking it runs ``tool.run_json`` with the received
    keyword arguments.

    Args:
        tool: The tool to expose as a kernel function.
        plugin_name: Optional plugin name forwarded to ``KernelFunctionFromMethod``.
    """

    def __init__(self, tool: BaseTool[InputT, OutputT], plugin_name: str | None = None):
        # Get the pydantic model types from the tool
        args_type = tool.args_type()
        return_type = tool.return_type()

        # 1) Define an async function that calls the tool
        @kernel_function(name=tool.name, description=tool.description)
        async def tool_method(**kwargs: dict[str, Any]) -> Any:
            return await tool.run_json(kwargs, cancellation_token=CancellationToken())

        # 2) Describe every schema property as kernel parameter metadata
        parameters_meta: list[KernelParameterMetadata] = []
        properties = tool.schema.get("parameters", {}).get("properties", {})

        # Get the field types from the pydantic model
        field_types = args_type.model_fields

        for prop_name, prop_info in properties.items():
            assert prop_name in field_types, f"Property {prop_name} not found in Tool {tool.name}"
            assert isinstance(prop_info, dict), f"Property {prop_name} is not a dict in Tool {tool.name}"

            # Get the actual type from the pydantic model field
            field_type = field_types[prop_name]
            parameters_meta.append(
                KernelParameterMetadata(
                    name=prop_name,
                    description=field_type.description or "",
                    default_value=field_type.get_default(),
                    type=prop_info.get("type", "string"),  # type: ignore
                    type_object=field_type.annotation,
                    is_required=field_type.is_required(),
                )
            )

        # issubclass() raises TypeError when the return type is not a plain
        # class (for example a parameterized generic such as ``list[str]``);
        # such return types are reported as "string" instead of crashing.
        try:
            returns_model = issubclass(return_type, BaseModel)
        except TypeError:
            returns_model = False

        # Create return parameter metadata
        return_parameter = KernelParameterMetadata(
            name="return",
            description=f"Result from '{tool.name}' tool",
            default_value=None,
            type="object" if returns_model else "string",
            type_object=return_type,
            is_required=True,
        )

        # Initialize the parent class
        super().__init__(
            method=tool_method,
            plugin_name=plugin_name,
            parameters=parameters_meta,
            return_parameter=return_parameter,
            additional_metadata=None,
        )

        # Keep a reference to the wrapped tool on the instance.
        self._tool = tool
class KernelFunctionFromToolSchema(KernelFunctionFromPrompt):
    """Semantic Kernel prompt function described by a ``ToolSchema``.

    The schema's name and description become the function's name and
    description, and each entry under ``parameters.properties`` becomes an
    input variable of the prompt template configuration.

    Args:
        tool_schema: Schema with optional ``name``, ``description`` and
            ``parameters`` entries; missing entries fall back to empty values.
        plugin_name: Optional plugin name forwarded to ``KernelFunctionFromPrompt``.
    """

    def __init__(self, tool_schema: ToolSchema, plugin_name: str | None = None):
        parameters = tool_schema.get("parameters", {})
        properties = parameters.get("properties", {})
        # JSON Schema keeps the list of required property names on the object
        # schema itself, next to "properties" -- not inside the properties mapping.
        required = parameters.get("required", [])

        function_name = tool_schema.get("name", "")
        description = tool_schema.get("description", "")

        prompt_template_config = PromptTemplateConfig(
            name=function_name,
            description=description,
            input_variables=[
                InputVariable(
                    name=prop_name,
                    description=prop_info.get("description", ""),
                    is_required=prop_name in required,
                )
                for prop_name, prop_info in properties.items()
            ],
        )

        super().__init__(
            function_name=function_name,
            plugin_name=plugin_name,
            description=description,
            prompt_template_config=prompt_template_config,
        )