first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
from ._chromadb import ChromaDBVectorMemory
# Public names of this package: the memory implementation imported from
# ``._chromadb`` plus every configuration model imported from ``._chroma_configs``.
__all__ = [
    "ChromaDBVectorMemory",
    "ChromaDBVectorMemoryConfig",
    "PersistentChromaDBVectorMemoryConfig",
    "HttpChromaDBVectorMemoryConfig",
    "DefaultEmbeddingFunctionConfig",
    "SentenceTransformerEmbeddingFunctionConfig",
    "OpenAIEmbeddingFunctionConfig",
    "CustomEmbeddingFunctionConfig",
]

View File

@@ -0,0 +1,137 @@
"""Configuration classes for ChromaDB vector memory."""
from typing import Any, Callable, Dict, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Configuration selecting ChromaDB's default embedding function.

    Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).
    This model carries no options of its own; it only identifies the choice.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator value for the ``EmbeddingFunctionConfig`` tagged union.
    function_type: Literal["default"] = "default"
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Configuration for SentenceTransformer embedding functions.

    Lets the caller pick a specific SentenceTransformer model by name.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:

        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig

            _ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator value for the ``EmbeddingFunctionConfig`` tagged union.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(
        description="SentenceTransformer model name to use",
        default="all-MiniLM-L6-v2",
    )
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Configuration for OpenAI embedding functions.

    Embeddings are generated through OpenAI's embedding API.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:

        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig

            _ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator value for the ``EmbeddingFunctionConfig`` tagged union.
    function_type: Literal["openai"] = "openai"
    api_key: str = Field(
        description="OpenAI API key",
        default="",
    )
    model_name: str = Field(
        description="OpenAI embedding model name",
        default="text-embedding-ada-002",
    )
class CustomEmbeddingFunctionConfig(BaseModel):
    """Configuration for custom embedding functions.

    Wraps a user-supplied factory: a callable that, given ``params`` as keyword
    arguments, returns a ChromaDB-compatible embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::

       Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.
    """

    # Discriminator value for the ``EmbeddingFunctionConfig`` tagged union.
    function_type: Literal["custom"] = "custom"
    # Required: no default is provided for the factory.
    function: Callable[..., Any] = Field(
        description="Function that returns an embedding function",
    )
    params: Dict[str, Any] = Field(
        description="Parameters to pass to the function",
        default_factory=dict,
    )
# Tagged union type for embedding function configurations.
# The ``function_type`` literal declared on each member model is the
# discriminator that selects which model a given payload is validated as.
EmbeddingFunctionConfig = Annotated[
    Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
        CustomEmbeddingFunctionConfig,
    ],
    Field(discriminator="function_type"),
]
class ChromaDBVectorMemoryConfig(BaseModel):
    """Base configuration for ChromaDB-based memory implementation.

    Holds the settings shared by every client flavour; ``client_type`` has no
    default here and is fixed by the concrete subclasses.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.

    Args:
        client_type (Literal["persistent", "http"]): Which kind of ChromaDB client to use.
        collection_name (str): Name of the ChromaDB collection. Defaults to "memory_store".
        distance_metric (str): Distance metric for similarity search. Defaults to "cosine".
        k (int): Number of results to return in queries. Defaults to 3.
        score_threshold (float | None): Minimum similarity score threshold. Defaults to None.
        allow_reset (bool): Whether to allow resetting the ChromaDB client. Defaults to False.
        tenant (str): Tenant to use. Defaults to "default_tenant".
        database (str): Database to use. Defaults to "default_database".
        embedding_function_config (EmbeddingFunctionConfig): Configuration for the
            embedding function. Defaults to a ``DefaultEmbeddingFunctionConfig``.
    """

    client_type: Literal["persistent", "http"]
    collection_name: str = Field(
        description="Name of the ChromaDB collection",
        default="memory_store",
    )
    distance_metric: str = Field(
        description="Distance metric for similarity search",
        default="cosine",
    )
    k: int = Field(
        description="Number of results to return in queries",
        default=3,
    )
    score_threshold: float | None = Field(
        description="Minimum similarity score threshold",
        default=None,
    )
    allow_reset: bool = Field(
        description="Whether to allow resetting the ChromaDB client",
        default=False,
    )
    tenant: str = Field(
        description="Tenant to use",
        default="default_tenant",
    )
    database: str = Field(
        description="Database to use",
        default="default_database",
    )
    embedding_function_config: EmbeddingFunctionConfig = Field(
        description="Configuration for the embedding function",
        default_factory=DefaultEmbeddingFunctionConfig,
    )
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for persistent ChromaDB memory.

    Args:
        client_type (Literal["persistent", "http"]): Defaults to "persistent".
        persistence_path (str): Path for persistent storage. Defaults to "./chroma_db".
    """

    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(
        description="Path for persistent storage",
        default="./chroma_db",
    )
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for HTTP ChromaDB memory.

    Args:
        client_type (Literal["persistent", "http"]): Defaults to "http".
        host (str): Host of the remote server. Defaults to "localhost".
        port (int): Port of the remote server. Defaults to 8000.
        ssl (bool): Whether to use HTTPS. Defaults to False.
        headers (Dict[str, str] | None): Headers to send to the server. Defaults to None.
    """

    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(
        description="Host of the remote server",
        default="localhost",
    )
    port: int = Field(
        description="Port of the remote server",
        default=8000,
    )
    ssl: bool = Field(
        description="Whether to use HTTPS",
        default=False,
    )
    headers: Dict[str, str] | None = Field(
        description="Headers to send to the server",
        default=None,
    )

View File

@@ -0,0 +1,459 @@
import logging
import uuid
from typing import Any, List
from agentdhal_core import CancellationToken, Component, Image
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from chromadb import HttpClient, PersistentClient
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Document, Metadata
from typing_extensions import Self
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
# Module-level logger used by ChromaDBVectorMemory for error reporting.
logger = logging.getLogger(__name__)

# Import guard: turn a missing chromadb installation into an ImportError that
# carries an installation hint.
# NOTE(review): chromadb names are already imported unguarded earlier in this
# module (HttpClient, PersistentClient, Collection, Document, Metadata), so a
# missing chromadb install appears to fail on those imports before this guard
# is reached — confirm, and consider moving them inside this ``try``.
# NOTE(review): the hint names ``autogen-ext[chromadb]`` while the class
# docstring says ``agentdhal-ext[chromadb]`` — confirm the correct package name.
try:
    from chromadb.api import ClientAPI
except ImportError as missing_dependency:
    raise ImportError(
        "To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`"
    ) from missing_dependency
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by ChromaDB.
`ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for
storing and retrieving content based on semantic similarity. It enhances agents with the ability
to recall contextually relevant information during conversations by leveraging vector embeddings
to find similar content.
This implementation serves as a reference for more complex memory systems using vector embeddings.
For advanced use cases requiring specialized formatting of retrieved content, users should extend
this class and override the `update_context()` method.
This implementation requires the ChromaDB extra to be installed. Install with:
.. code-block:: bash
pip install "agentdhal-ext[chromadb]"
Args:
config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory.
If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values.
Two config types are supported:
* PersistentChromaDBVectorMemoryConfig: For local storage
* HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server
Example:
.. code-block:: python
import os
import asyncio
from pathlib import Path
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core.memory import MemoryContent, MemoryMimeType
from agentdhal_extensions.memory.chromadb import (
ChromaDBVectorMemory,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
OpenAIEmbeddingFunctionConfig,
)
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
def get_weather(city: str) -> str:
return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."
def fahrenheit_to_celsius(fahrenheit: float) -> float:
return (fahrenheit - 32) * 5.0 / 9.0
async def main() -> None:
# Use default embedding function
default_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="user_preferences",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
k=3, # Return top 3 results
score_threshold=0.5, # Minimum similarity score
)
)
# Using a custom SentenceTransformer model
custom_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="multilingual_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
model_name="paraphrase-multilingual-mpnet-base-v2"
),
)
)
# Using OpenAI embeddings
openai_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="openai_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=OpenAIEmbeddingFunctionConfig(
api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
),
)
)
# Add user preferences to memory
await openai_memory.add(
MemoryContent(
content="The user prefers weather temperatures in Celsius",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "units"},
)
)
# Create assistant agent with ChromaDB memory
assistant = AssistantAgent(
name="assistant",
model_client=OpenAIChatCompletionClient(
model="gpt-4.1",
),
tools=[
get_weather,
fahrenheit_to_celsius,
],
max_tool_iterations=10,
memory=[openai_memory],
)
# The memory will automatically retrieve relevant content during conversations
await Console(assistant.run_stream(task="What's the temperature in New York?"))
# Remember to close the memory when finished
await default_memory.close()
await custom_memory.close()
await openai_memory.close()
asyncio.run(main())
Output:
.. code-block:: text
---------- TextMessage (user) ----------
What's the temperature in New York?
---------- MemoryQueryEvent (assistant) ----------
[MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)]
---------- TextMessage (assistant) ----------
The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
"""
# Config model type associated with this component.
# NOTE(review): presumably consumed by the ``Component`` base class when
# dumping/loading component configs — confirm against agentdhal_core.
component_config_schema = ChromaDBVectorMemoryConfig
# Dotted import path recorded as this component's provider (public package path
# rather than the private ``_chromadb`` module) — presumably used on load; confirm.
component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory"
def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
    """Store the configuration; no ChromaDB connection is opened here.

    Args:
        config: Memory configuration. When omitted (or falsy), a
            ``PersistentChromaDBVectorMemoryConfig`` with default values is used.
    """
    # Client and collection start unset; ``_ensure_initialized`` fills them in.
    self._client: ClientAPI | None = None
    self._collection: Collection | None = None
    self._config = config if config else PersistentChromaDBVectorMemoryConfig()
@property
def collection_name(self) -> str:
    """Name of the ChromaDB collection, as given in the configuration."""
    active_config = self._config
    return active_config.collection_name
def _create_embedding_function(self) -> Any:
    """Build the embedding function selected by ``embedding_function_config``.

    Returns:
        The embedding function object for the configured type: ChromaDB's default,
        SentenceTransformer or OpenAI embedding function, or whatever the custom
        factory returns when called with its configured ``params``.

    Raises:
        ImportError: If ``chromadb.utils.embedding_functions`` cannot be imported, or if
            constructing the SentenceTransformer or OpenAI embedding function fails for
            any reason (the underlying exception is chained).
        ValueError: If the custom factory raises, or the config type is not recognised.
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as e:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from e

    config = self._config.embedding_function_config

    if isinstance(config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(config, SentenceTransformerEmbeddingFunctionConfig):
        # Any construction failure is reported as ImportError, not only missing packages.
        try:
            sentence_transformer_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=config.model_name
            )
        except Exception as e:
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {e}"
            ) from e
        return sentence_transformer_fn

    if isinstance(config, OpenAIEmbeddingFunctionConfig):
        # Same convention as above: every failure surfaces as ImportError.
        try:
            openai_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=config.api_key,
                model_name=config.model_name,
            )
        except Exception as e:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {e}"
            ) from e
        return openai_fn

    if isinstance(config, CustomEmbeddingFunctionConfig):
        # The user-supplied factory is invoked with the configured keyword arguments.
        try:
            custom_fn = config.function(**config.params)
        except Exception as e:
            raise ValueError(f"Failed to create custom embedding function. Error: {e}") from e
        return custom_fn

    raise ValueError(f"Unsupported embedding function config type: {type(config)}")
def _ensure_initialized(self) -> None:
    """Lazily create the ChromaDB client and collection if they are not set yet.

    The client is built from the config type (persistent or HTTP); the collection is
    obtained with ``get_or_create_collection`` using the configured embedding function.
    Already-initialized parts are left untouched.

    Raises:
        ValueError: If the config is neither a persistent nor an HTTP config.
        Exception: Any error raised while creating the client or the collection is
            logged and re-raised unchanged.
    """
    cfg = self._config

    if self._client is None:
        try:
            from chromadb.config import Settings

            settings = Settings(allow_reset=cfg.allow_reset)
            if isinstance(cfg, PersistentChromaDBVectorMemoryConfig):
                self._client = PersistentClient(
                    path=cfg.persistence_path,
                    settings=settings,
                    tenant=cfg.tenant,
                    database=cfg.database,
                )
            elif isinstance(cfg, HttpChromaDBVectorMemoryConfig):
                self._client = HttpClient(
                    host=cfg.host,
                    port=cfg.port,
                    ssl=cfg.ssl,
                    headers=cfg.headers,
                    settings=settings,
                    tenant=cfg.tenant,
                    database=cfg.database,
                )
            else:
                raise ValueError(f"Unsupported config type: {type(self._config)}")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise

    if self._collection is not None:
        return

    try:
        # Build the embedding function first so its failures are logged below too.
        embedding_function = self._create_embedding_function()
        # NOTE(review): the metric is stored under the metadata key "distance_metric".
        # Chroma's documented key for the index distance function is "hnsw:space";
        # confirm this setting actually influences the similarity search.
        collection_metadata = {"distance_metric": cfg.distance_metric}
        self._collection = self._client.get_or_create_collection(
            name=cfg.collection_name,
            metadata=collection_metadata,
            embedding_function=embedding_function,
        )
    except Exception as e:
        logger.error(f"Failed to get/create collection: {e}")
        raise
def _extract_text(self, content_item: str | MemoryContent) -> str:
    """Return the text that is embedded/searched for ``content_item``.

    Args:
        content_item: Either a plain string (returned unchanged) or a ``MemoryContent``.

    Returns:
        ``str(content)`` for TEXT and MARKDOWN content; for JSON content, the
        lower-cased ``str()`` of the dict.

    Raises:
        ValueError: If JSON content is not a dict, if the content is an ``Image``,
            or if the mime type is otherwise unsupported.
    """
    if isinstance(content_item, str):
        return content_item

    content = content_item.content
    mime_type = content_item.mime_type

    if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
        return str(content)

    if mime_type == MemoryMimeType.JSON:
        if not isinstance(content, dict):
            raise ValueError("JSON content must be a dict")
        # This is the dict's Python ``str()`` form lower-cased, not a ``json.dumps`` dump.
        return str(content).lower()

    if isinstance(content, Image):
        raise ValueError("Image content cannot be converted to text")

    raise ValueError(f"Unsupported content type: {mime_type}")
def _calculate_score(self, distance: float) -> float:
    """Convert a ChromaDB distance into a similarity score (higher is more similar).

    Args:
        distance: Distance reported by ChromaDB for one result.

    Returns:
        ``1 - distance / 2`` when the configured metric is "cosine" (distance 0 maps to
        1.0, distance 2 maps to 0.0); otherwise ``1 / (1 + distance)``.
    """
    uses_cosine = self._config.distance_metric == "cosine"
    if not uses_cosine:
        return 1.0 / (1.0 + distance)
    return 1.0 - (distance / 2.0)
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Query memory with the latest message and append the hits to the model context.

    Args:
        model_context: Context whose last message is used as the query and to which
            a ``SystemMessage`` listing the retrieved memories is added.

    Returns:
        An ``UpdateContextResult`` wrapping the query results (empty when the context
        has no messages).
    """
    messages = await model_context.get_messages()
    if not messages:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    # The query is the last message's content when it is a string; otherwise the
    # string form of the whole message object.
    last_message = messages[-1]
    if isinstance(last_message.content, str):
        query_text = last_message.content
    else:
        query_text = str(last_message)

    query_results = await self.query(query_text)

    if query_results.results:
        # Numbered list (starting at 1) of the retrieved memory contents.
        memory_strings = []
        for i, memory in enumerate(query_results.results, 1):
            memory_strings.append(f"{i}. {str(memory.content)}")
        memory_context = "\nRelevant memory content:\n" + "\n".join(memory_strings)
        await model_context.add_message(SystemMessage(content=memory_context))

    return UpdateContextResult(memories=query_results)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Store one memory item in the ChromaDB collection.

    The item's text (see ``_extract_text``) is added as a document under a freshly
    generated UUID, together with the item's metadata plus a ``mime_type`` entry.

    Args:
        content: The memory item to store.
        cancellation_token: Accepted for interface compatibility; not used here.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error raised while extracting text or adding to ChromaDB is
            logged and re-raised unchanged.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        # Extract text from content
        text = self._extract_text(content)

        # Work on a copy: writing "mime_type" straight into ``content.metadata``
        # would silently modify the caller's MemoryContent.
        metadata_dict = dict(content.metadata or {})
        # NOTE(review): ``str()`` of an enum member stores e.g. "MemoryMimeType.TEXT"
        # (see the example output in the class docstring) rather than the mime
        # string itself — confirm whether ``.value`` was intended.
        metadata_dict["mime_type"] = str(content.mime_type)

        # Add to ChromaDB
        self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())])
    except Exception as e:
        logger.error(f"Failed to add content to ChromaDB: {e}")
        raise
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Return the stored memories most similar to ``query``.

    Args:
        query: Query text, or a ``MemoryContent`` whose text is extracted with
            ``_extract_text``.
        cancellation_token: Accepted for interface compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to the collection's ``query`` call.

    Returns:
        A ``MemoryQueryResult`` with at most ``k`` items. Each item's metadata is the
        stored metadata plus ``score`` and ``id``; items scoring below
        ``score_threshold`` (when set) are dropped.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error raised during the query is logged and re-raised unchanged.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        # Extract text for query
        query_text = self._extract_text(query)

        # Query ChromaDB
        results = self._collection.query(
            query_texts=[query_text],
            n_results=self._config.k,
            include=["documents", "metadatas", "distances"],
            **kwargs,
        )

        # Convert results to MemoryContent list
        memory_results: List[MemoryContent] = []

        if (
            not results
            or not results.get("documents")
            or not results.get("metadatas")
            or not results.get("distances")
        ):
            return MemoryQueryResult(results=memory_results)

        # Results are nested per query text; a single query was sent, so take index 0.
        documents: List[Document] = results["documents"][0] if results["documents"] else []
        metadatas: List[Metadata] = results["metadatas"][0] if results["metadatas"] else []
        distances: List[float] = results["distances"][0] if results["distances"] else []
        ids: List[str] = results["ids"][0] if results["ids"] else []

        threshold = self._config.score_threshold

        for doc, metadata_dict, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False):
            score = self._calculate_score(distance)

            # Skip low-scoring hits before building any metadata for them.
            if threshold is not None and score < threshold:
                continue

            # A metadata entry can be None (e.g. documents written to the collection
            # without metadata by other code); treat that as empty instead of crashing.
            stored_metadata = metadata_dict or {}
            metadata = dict(stored_metadata)
            metadata["score"] = score
            metadata["id"] = doc_id

            # Fall back to plain text when no mime type was stored.
            mime_type = str(stored_metadata.get("mime_type", MemoryMimeType.TEXT.value))

            memory_results.append(
                MemoryContent(
                    content=doc,
                    mime_type=mime_type,
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=memory_results)

    except Exception as e:
        logger.error(f"Failed to query ChromaDB: {e}")
        raise
async def clear(self) -> None:
    """Delete every entry currently stored in the collection.

    All ids are fetched with ``get()`` and then removed in a single ``delete`` call;
    nothing is deleted when the collection reports no ids.

    Raises:
        RuntimeError: If the collection is still unset after initialization.
        Exception: Any error raised by ChromaDB is logged and re-raised unchanged.
    """
    self._ensure_initialized()
    collection = self._collection
    if collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        stored = collection.get()
        if not stored or not stored["ids"]:
            return
        collection.delete(ids=stored["ids"])
    except Exception as e:
        logger.error(f"Failed to clear ChromaDB collection: {e}")
        raise
async def close(self) -> None:
    """Drop the references to the ChromaDB client and collection.

    Both attributes are set to ``None``; a later operation re-creates them through
    ``_ensure_initialized``.
    """
    self._client = None
    self._collection = None
async def reset(self) -> None:
    """Reset the ChromaDB client, if the configuration permits it.

    The permission check runs before any initialization, so a disallowed reset
    fails fast without creating a client or collection as a side effect.

    Raises:
        RuntimeError: If ``allow_reset`` is False in the configuration.
    """
    if not self._config.allow_reset:
        raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.")

    self._ensure_initialized()

    if self._client is not None:
        try:
            self._client.reset()
        except Exception as e:
            # Best effort: a failed reset is logged, not propagated.
            logger.error(f"Error during ChromaDB reset: {e}")
        finally:
            # Forget the collection handle either way so the next operation
            # obtains it again through ``_ensure_initialized``.
            self._collection = None
def _to_config(self) -> ChromaDBVectorMemoryConfig:
    """Return the configuration object this memory holds (the same instance, not a copy)."""
    held_config = self._config
    return held_config
@classmethod
def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
    """Build a new instance of ``cls`` from ``config``."""
    instance = cls(config=config)
    return instance