first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,4 @@
from ._text_canvas import TextCanvas
from ._text_canvas_memory import TextCanvasMemory
__all__ = ["TextCanvas", "TextCanvasMemory"]

View File

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
class BaseCanvas(ABC):
    """Abstract interface for "canvas" objects that keep revision history for file-like data.

    Concrete subclasses decide which payload they store (text, images,
    structured data, ...). Every method below is abstract; the base
    implementations only raise :class:`NotImplementedError`.

    .. warning::

        This is an experimental API and may change in the future.
    """

    @abstractmethod
    def list_files(self) -> Dict[str, int]:
        """Return a mapping of filename to that file's latest revision number."""
        raise NotImplementedError

    @abstractmethod
    def get_latest_content(self, filename: str) -> Union[str, bytes, Any]:
        """Return the content stored in the newest revision of *filename*."""
        raise NotImplementedError

    @abstractmethod
    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """Create *filename*, or record *new_content* as a new revision of it."""
        raise NotImplementedError

    @abstractmethod
    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """Return a diff (format chosen by the subclass) between two revisions."""
        raise NotImplementedError

    @abstractmethod
    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """Apply a patch/diff to the latest revision and store the result as a new revision."""
        raise NotImplementedError

View File

@@ -0,0 +1,64 @@
from agentdhal_core import CancellationToken
from agentdhal_core.tools import BaseTool
from pydantic import BaseModel
from ._text_canvas import TextCanvas
class UpdateFileArgs(BaseModel):
    # Argument model for UpdateFileTool (kept as a comment so the generated
    # JSON schema of the model is unchanged).

    # Name of the canvas file to create or overwrite.
    filename: str
    # Complete text that becomes the file's newest revision.
    new_content: str
class UpdateFileResult(BaseModel):
    # Result model returned by UpdateFileTool.run.

    # Short status marker; UpdateFileTool reports "OK".
    status: str
class UpdateFileTool(BaseTool[UpdateFileArgs, UpdateFileResult]):
    """Tool that creates a canvas file or overwrites it with new content.

    Args:
        canvas: The :class:`TextCanvas` this tool writes to.
    """

    def __init__(self, canvas: TextCanvas):
        super().__init__(
            args_type=UpdateFileArgs,
            return_type=UpdateFileResult,
            name="update_file",
            description="Create/update a file on the canvas with the provided content.",
        )
        self._canvas = canvas

    async def run(self, args: UpdateFileArgs, cancellation_token: CancellationToken) -> UpdateFileResult:
        """Store ``args.new_content`` as the newest revision of ``args.filename``.

        The cancellation token is accepted for interface compatibility and is
        not consulted; the canvas write is synchronous.
        """
        target, text = args.filename, args.new_content
        self._canvas.add_or_update_file(target, text)
        return UpdateFileResult(status="OK")
class ApplyPatchArgs(BaseModel):
    # Argument model for ApplyPatchTool (kept as a comment so the generated
    # JSON schema of the model is unchanged).

    # Name of the existing canvas file the patch targets.
    filename: str
    # Patch text in unified diff format.
    patch_text: str
class ApplyPatchResult(BaseModel):
    # Result model returned by ApplyPatchTool.run.

    # Short status marker; ApplyPatchTool reports "PATCH APPLIED".
    status: str
class ApplyPatchTool(BaseTool[ApplyPatchArgs, ApplyPatchResult]):
    """Tool that applies a unified diff patch to a file already on the canvas.

    Args:
        canvas: The :class:`TextCanvas` whose file is patched.
    """

    def __init__(self, canvas: TextCanvas):
        tool_description = (
            "Apply a unified diff patch to an existing file on the canvas. "
            "The patch must be in diff/patch format. The file must exist or be created first."
        )
        super().__init__(
            args_type=ApplyPatchArgs,
            return_type=ApplyPatchResult,
            name="apply_patch",
            description=tool_description,
        )
        self._canvas = canvas

    async def run(self, args: ApplyPatchArgs, cancellation_token: CancellationToken) -> ApplyPatchResult:
        """Apply ``args.patch_text`` to ``args.filename`` and report success.

        Any exception raised by the canvas propagates to the caller. The
        cancellation token is accepted for interface compatibility and is not
        consulted.
        """
        target, patch = args.filename, args.patch_text
        self._canvas.apply_patch(target, patch)
        return ApplyPatchResult(status="PATCH APPLIED")

View File

@@ -0,0 +1,192 @@
import difflib
from typing import Any, Dict, List, Union
try: # pragma: no cover
from unidiff import PatchSet
except ModuleNotFoundError: # pragma: no cover
PatchSet = None # type: ignore
from ._canvas import BaseCanvas
class FileRevision:
    """One stored snapshot of a file: its text plus the revision number it was saved under."""

    # Slots keep each snapshot small; a canvas may hold many of them.
    __slots__ = ("content", "revision")

    def __init__(self, content: str, revision: int) -> None:
        # ``revision`` is an integer identifier for this snapshot.
        self.content, self.revision = content, revision
class TextCanvas(BaseCanvas):
    """An in-memory canvas that stores *text* files with full revision history.

    .. warning::

        This is an experimental API and may change in the future.

    Besides the CRUD-like operations required by :class:`BaseCanvas`, this
    implementation adds:

    * **apply_patch** – parses a unified diff with the ``unidiff`` library and
      applies its hunks by line number. Context lines are *not* compared
      against the stored content, so a patch with wrong line numbers is
      applied at the stated position rather than rejected.
    * **get_revision_content** – random access to any historical revision.
    * **get_revision_diffs** – the list of diffs between every consecutive
      pair of revisions, so a caller can replay or audit the change history.
    """

    # ----------------------------------------------------------------------------------
    # Construction helpers
    # ----------------------------------------------------------------------------------

    def __init__(self) -> None:
        # For each file we keep an *ordered* list of FileRevision where the
        # last element is the most recent one.
        self._files: Dict[str, List[FileRevision]] = {}

    # ----------------------------------------------------------------------------------
    # Internal utilities
    # ----------------------------------------------------------------------------------

    def _latest_idx(self, filename: str) -> int:
        """Return the list index (not revision number) of the newest revision.

        Returns ``-1`` when *filename* is unknown.
        """
        return len(self._files.get(filename, [])) - 1

    def _ensure_file(self, filename: str) -> None:
        """Raise :class:`ValueError` unless *filename* already exists on the canvas."""
        if filename not in self._files:
            raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.")

    # ----------------------------------------------------------------------------------
    # Revision inspection helpers
    # ----------------------------------------------------------------------------------

    def get_revision_content(self, filename: str, revision: int) -> str:
        """Return the exact content stored in *revision* of *filename*.

        An empty string is returned when the file or the revision does not
        exist, so callers can handle "not found" without exceptions. Note that
        this is indistinguishable from a revision whose content is empty.
        """
        for rev in self._files.get(filename, []):
            if rev.revision == revision:
                return rev.content
        return ""

    def get_revision_diffs(self, filename: str) -> List[str]:
        """Return a *chronological* list of unified diffs for *filename*.

        Element ``i`` is the diff that turns the ``i``-th stored revision into
        the next one. Unknown files and files with a single revision yield an
        empty list.
        """
        revisions = self._files.get(filename, [])
        diffs: List[str] = []
        for older, newer in zip(revisions, revisions[1:]):
            diff = difflib.unified_diff(
                older.content.splitlines(keepends=True),
                newer.content.splitlines(keepends=True),
                fromfile=f"{filename}@r{older.revision}",
                tofile=f"{filename}@r{newer.revision}",
            )
            diffs.append("".join(diff))
        return diffs

    # ----------------------------------------------------------------------------------
    # BaseCanvas interface implementation
    # ----------------------------------------------------------------------------------

    def list_files(self) -> Dict[str, int]:
        """Return a mapping of *filename → latest revision number*."""
        return {fname: revs[-1].revision for fname, revs in self._files.items() if revs}

    def get_latest_content(self, filename: str) -> str:  # noqa: D401
        """Return the most recent content, or an empty string if the file is unknown."""
        revs = self._files.get(filename, [])
        return revs[-1].content if revs else ""

    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """Create *filename* or append a new revision containing *new_content*.

        ``bytes`` input is decoded as UTF-8. The first revision of a file is
        numbered 1 and every update increments the number by one.

        Raises:
            ValueError: If *new_content* is neither ``str`` nor ``bytes``.
        """
        if isinstance(new_content, bytes):
            new_content = new_content.decode("utf-8")
        if not isinstance(new_content, str):
            raise ValueError(f"Expected str or bytes, got {type(new_content)}")

        revisions = self._files.get(filename)
        if revisions is None:
            self._files[filename] = [FileRevision(new_content, 1)]
        else:
            revisions.append(FileRevision(new_content, revisions[-1].revision + 1))

    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """Return a unified diff between *from_revision* and *to_revision*.

        An empty string is returned when the file is unknown or when both
        revisions resolve to empty content. If only one of the two revisions is
        missing, it is treated as empty content and the diff is computed
        against that.
        """
        if not self._files.get(filename):
            return ""

        from_content = self.get_revision_content(filename, from_revision)
        to_content = self.get_revision_content(filename, to_revision)
        # Nothing to compare only when *both* sides are empty (missing
        # revisions resolve to ""); a single missing side still produces a diff.
        if from_content == "" and to_content == "":
            return ""

        diff = difflib.unified_diff(
            from_content.splitlines(keepends=True),
            to_content.splitlines(keepends=True),
            fromfile=f"{filename}@r{from_revision}",
            tofile=f"{filename}@r{to_revision}",
        )
        return "".join(diff)

    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """Apply *patch_data* (a unified diff) to the latest revision and save a new revision.

        The patch is parsed with the ``unidiff`` library; only the first file
        section of the patch is used. Hunks are applied purely by the line
        numbers in their headers – context lines are not checked against the
        current content.

        Raises:
            ValueError: If *patch_data* is neither ``str`` nor ``bytes``, if
                *filename* does not exist, or if the patch contains no file
                sections.
            ImportError: If the optional ``unidiff`` dependency is missing.
        """
        if isinstance(patch_data, bytes):
            patch_data = patch_data.decode("utf-8")
        if not isinstance(patch_data, str):
            raise ValueError(f"Expected str or bytes, got {type(patch_data)}")

        self._ensure_file(filename)
        original_content = self.get_latest_content(filename)

        if PatchSet is None:
            raise ImportError(
                "The 'unidiff' package is required for patch application. Install with 'pip install unidiff'."
            )

        patch = PatchSet(patch_data)
        # The canvas patches exactly one file per operation, so only the
        # first patched-file section is considered.
        if not patch:
            raise ValueError("Empty patch text provided.")
        patched_file = patch[0]

        working_lines = original_content.splitlines(keepends=True)
        # Net number of lines added by the hunks applied so far; later hunks
        # still refer to line numbers of the *original* content.
        line_offset = 0

        for hunk in patched_file:
            if hunk.source_length == 0:
                # Pure insertion ("-N,0"): the header names the line *after
                # which* the new lines go, so index N is already the insertion
                # point. This also keeps "-0,0" from becoming index -1, which
                # would insert before the last line instead of at the top.
                start = hunk.source_start + line_offset
            else:
                start = hunk.source_start - 1 + line_offset
            end = start + hunk.source_length

            # Removed lines are dropped; context and added lines form the result.
            replacement: List[str] = [line.value for line in hunk if line.is_added or line.is_context]

            working_lines[start:end] = replacement
            line_offset += len(replacement) - hunk.source_length

        # Commit the patched text as a new revision.
        self.add_or_update_file(filename, "".join(working_lines))

    # ----------------------------------------------------------------------------------
    # Convenience helpers
    # ----------------------------------------------------------------------------------

    def get_all_contents_for_context(self) -> str:  # noqa: D401
        """Return a summarised view of every file and its *latest* revision.

        The result always starts with a header line and ends with a footer
        line, even when the canvas holds no files.
        """
        out: List[str] = ["=== CANVAS FILES ==="]
        for fname, revs in self._files.items():
            if not revs:
                # Mirrors list_files(): a file without revisions has nothing to show.
                continue
            latest = revs[-1]
            out.append(f"File: {fname} (rev {latest.revision}):\n{latest.content}\n")
        out.append("=== END OF CANVAS ===")
        return "\n".join(out)

View File

@@ -0,0 +1,229 @@
from typing import Any, Optional
from agentdhal_core import CancellationToken
from agentdhal_core.memory import (
Memory,
MemoryContent,
MemoryMimeType,
MemoryQueryResult,
UpdateContextResult,
)
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from ._canvas_writer import ApplyPatchTool, UpdateFileTool
from ._text_canvas import TextCanvas
class TextCanvasMemory(Memory):
    """A memory implementation that uses a canvas for storing file-like content.

    The current state of the canvas is inserted into the
    ``ChatCompletionContext`` on each turn, as long as the canvas holds at
    least one file.

    .. warning::

        This is an experimental API and may change in the future.

    The TextCanvasMemory provides a file-like storage mechanism that agents can
    use to read and write content. It injects the latest revision of every file
    on the canvas into the model context before each inference.

    This is particularly useful for:

    - Allowing agents to create and modify documents over multiple turns
    - Enabling collaborative document editing between multiple agents
    - Maintaining state across conversation turns
    - Working with content too large to fit in a single message

    The canvas provides tools for:

    - Creating or updating files with new content
    - Applying patches (unified diff format) to existing files

    Args:
        canvas: The canvas to wrap. A new empty :class:`TextCanvas` is created
            when omitted.

    Examples:

        **Example: Using TextCanvasMemory with an AssistantAgent**

        The following example demonstrates how to create a TextCanvasMemory and use it with
        an AssistantAgent to write and update a story file.

        .. code-block:: python

            import asyncio
            from agentdhal_core import CancellationToken
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.messages import TextMessage
            from agentdhal_extensions.memory.canvas import TextCanvasMemory


            async def main():
                # Create a model client
                model_client = OpenAIChatCompletionClient(
                    model="gpt-4o",
                    # api_key = "your_openai_api_key"
                )

                # Create the canvas memory
                text_canvas_memory = TextCanvasMemory()

                # Get tools for working with the canvas
                update_file_tool = text_canvas_memory.get_update_file_tool()
                apply_patch_tool = text_canvas_memory.get_apply_patch_tool()

                # Create an agent with the canvas memory and tools
                writer_agent = AssistantAgent(
                    name="Writer",
                    model_client=model_client,
                    description="A writer agent that creates and updates stories.",
                    system_message='''
                    You are a Writer Agent. Your focus is to generate a story based on the user's request.

                    Instructions for using the canvas:

                    - The story should be stored on the canvas in a file named "story.md".
                    - If "story.md" does not exist, create it by calling the 'update_file' tool.
                    - If "story.md" already exists, generate a unified diff (patch) from the current
                      content to the new version, and call the 'apply_patch' tool to apply the changes.

                    IMPORTANT: Do not include the full story text in your chat messages.
                    Only write the story content to the canvas using the tools.
                    ''',
                    tools=[update_file_tool, apply_patch_tool],
                    memory=[text_canvas_memory],
                )

                # Send a message to the agent
                await writer_agent.on_messages(
                    [TextMessage(content="Write a short story about a bunny and a sunflower.", source="user")],
                    CancellationToken(),
                )

                # Retrieve the content from the canvas
                story_content = text_canvas_memory.canvas.get_latest_content("story.md")
                print("Story content from canvas:")
                print(story_content)


            if __name__ == "__main__":
                asyncio.run(main())

        **Example: Using TextCanvasMemory with multiple agents**

        The following example shows how to use TextCanvasMemory with multiple agents
        collaborating on the same document.

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_extensions.memory.canvas import TextCanvasMemory


            async def main():
                # Create a model client
                model_client = OpenAIChatCompletionClient(
                    model="gpt-4o",
                    # api_key = "your_openai_api_key"
                )

                # Create the shared canvas memory
                text_canvas_memory = TextCanvasMemory()
                update_file_tool = text_canvas_memory.get_update_file_tool()
                apply_patch_tool = text_canvas_memory.get_apply_patch_tool()

                # Create a writer agent
                writer_agent = AssistantAgent(
                    name="Writer",
                    model_client=model_client,
                    description="A writer agent that creates stories.",
                    system_message="You write children's stories on the canvas in story.md.",
                    tools=[update_file_tool, apply_patch_tool],
                    memory=[text_canvas_memory],
                )

                # Create a critique agent
                critique_agent = AssistantAgent(
                    name="Critique",
                    model_client=model_client,
                    description="A critique agent that provides feedback on stories.",
                    system_message="You review the story.md file and provide constructive feedback.",
                    memory=[text_canvas_memory],
                )

                # Create a team with both agents
                team = RoundRobinGroupChat(
                    participants=[writer_agent, critique_agent],
                    termination_condition=TextMentionTermination("TERMINATE"),
                    max_turns=10,
                )

                # Run the team on a task
                await team.run(task="Create a children's book about a bunny and a sunflower")

                # Get the final story
                story = text_canvas_memory.canvas.get_latest_content("story.md")
                print(story)


            if __name__ == "__main__":
                asyncio.run(main())
    """

    def __init__(self, canvas: Optional[TextCanvas] = None):
        super().__init__()
        self.canvas = canvas if canvas is not None else TextCanvas()

    async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
        """Inject a summary of the canvas into *model_context* as a system message.

        Nothing is added, and an empty result is returned, when the canvas
        holds no files. Otherwise the injected snapshot is also returned so it
        can be logged or inspected.
        """
        # The snapshot text always contains header/footer lines, so it can never
        # be blank; emptiness has to be decided from the file listing instead.
        if not self.canvas.list_files():
            return UpdateContextResult(memories=MemoryQueryResult(results=[]))

        snapshot = self.canvas.get_all_contents_for_context()
        await model_context.add_message(SystemMessage(content=snapshot))

        # Return the snapshot for debugging/logging.
        memory_content = MemoryContent(content=snapshot, mime_type=MemoryMimeType.TEXT)
        return UpdateContextResult(memories=MemoryQueryResult(results=[memory_content]))

    async def query(
        self, query: str | MemoryContent, cancellation_token: Optional[CancellationToken] = None, **kwargs: Any
    ) -> MemoryQueryResult:
        """Return an empty result; searching the canvas is not implemented.

        A future version could match *query* against filenames or file content.
        """
        return MemoryQueryResult(results=[])

    async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None:
        """Do nothing; canvas changes are made through the canvas tools instead.

        A future version could interpret *content* as a patch or a direct file
        update.
        """
        # Intentional no-op: writes go through UpdateFileTool / ApplyPatchTool.
        return None

    async def clear(self) -> None:
        """Clear the entire canvas by replacing it with a new empty instance.

        Tools obtained earlier from :meth:`get_update_file_tool` or
        :meth:`get_apply_patch_tool` keep referring to the previous canvas.
        """
        self.canvas = TextCanvas()

    async def close(self) -> None:
        """Release resources; the in-memory canvas holds none, so this is a no-op."""
        return None

    def get_update_file_tool(self) -> UpdateFileTool:
        """Return an :class:`UpdateFileTool` bound to this memory's current canvas."""
        return UpdateFileTool(self.canvas)

    def get_apply_patch_tool(self) -> ApplyPatchTool:
        """Return an :class:`ApplyPatchTool` bound to this memory's current canvas."""
        return ApplyPatchTool(self.canvas)

View File

@@ -0,0 +1,21 @@
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
from ._chromadb import ChromaDBVectorMemory
__all__ = [
"ChromaDBVectorMemory",
"ChromaDBVectorMemoryConfig",
"PersistentChromaDBVectorMemoryConfig",
"HttpChromaDBVectorMemoryConfig",
"DefaultEmbeddingFunctionConfig",
"SentenceTransformerEmbeddingFunctionConfig",
"OpenAIEmbeddingFunctionConfig",
"CustomEmbeddingFunctionConfig",
]

View File

@@ -0,0 +1,137 @@
"""Configuration classes for ChromaDB vector memory."""
from typing import Any, Callable, Dict, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Configuration for the default ChromaDB embedding function.

    Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator value that selects this variant of the embedding-config union.
    function_type: Literal["default"] = "default"
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Configuration for SentenceTransformer embedding functions.

    Allows specifying a custom SentenceTransformer model for embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig

            _ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator value that selects this variant of the embedding-config union.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="SentenceTransformer model name to use",
    )
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Configuration for OpenAI embedding functions.

    Uses OpenAI's embedding API for generating embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig

            _ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator value that selects this variant of the embedding-config union.
    function_type: Literal["openai"] = "openai"
    api_key: str = Field(
        default="",
        description="OpenAI API key",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="OpenAI embedding model name",
    )
class CustomEmbeddingFunctionConfig(BaseModel):
    """Configuration for custom embedding functions.

    Allows using a custom function that returns a ChromaDB-compatible embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::
       Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.
    """

    # Discriminator value that selects this variant of the embedding-config union.
    function_type: Literal["custom"] = "custom"
    # Required: there is no default factory.
    function: Callable[..., Any] = Field(
        description="Function that returns an embedding function",
    )
    # Passed to ``function`` as keyword arguments; defaults to a fresh empty dict.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Parameters to pass to the function",
    )
# Tagged union of every supported embedding-function configuration. Pydantic
# picks the concrete model from the value of the ``function_type`` field.
EmbeddingFunctionConfig = Annotated[
    Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
        CustomEmbeddingFunctionConfig,
    ],
    Field(discriminator="function_type"),
]
class ChromaDBVectorMemoryConfig(BaseModel):
    """Base configuration for ChromaDB-based memory implementation.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.
    """

    # Required on the base class; each subclass pins its own default.
    client_type: Literal["persistent", "http"]

    # --- collection and query settings ---
    collection_name: str = Field(
        default="memory_store",
        description="Name of the ChromaDB collection",
    )
    distance_metric: str = Field(
        default="cosine",
        description="Distance metric for similarity search",
    )
    k: int = Field(
        default=3,
        description="Number of results to return in queries",
    )
    score_threshold: float | None = Field(
        default=None,
        description="Minimum similarity score threshold",
    )

    # --- client settings ---
    allow_reset: bool = Field(
        default=False,
        description="Whether to allow resetting the ChromaDB client",
    )
    tenant: str = Field(
        default="default_tenant",
        description="Tenant to use",
    )
    database: str = Field(
        default="default_database",
        description="Database to use",
    )

    # --- embedding settings ---
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function",
    )
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for persistent ChromaDB memory."""

    # Same declared type as the base class; only the default is pinned here.
    client_type: Literal["persistent", "http"] = "persistent"
    persistence_path: str = Field(
        default="./chroma_db",
        description="Path for persistent storage",
    )
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for HTTP ChromaDB memory."""

    # Same declared type as the base class; only the default is pinned here.
    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(
        default="localhost",
        description="Host of the remote server",
    )
    port: int = Field(
        default=8000,
        description="Port of the remote server",
    )
    ssl: bool = Field(
        default=False,
        description="Whether to use HTTPS",
    )
    headers: Dict[str, str] | None = Field(
        default=None,
        description="Headers to send to the server",
    )

View File

@@ -0,0 +1,459 @@
import logging
import uuid
from typing import Any, List
from agentdhal_core import CancellationToken, Component, Image
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from chromadb import HttpClient, PersistentClient
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Document, Metadata
from typing_extensions import Self
from ._chroma_configs import (
ChromaDBVectorMemoryConfig,
CustomEmbeddingFunctionConfig,
DefaultEmbeddingFunctionConfig,
HttpChromaDBVectorMemoryConfig,
OpenAIEmbeddingFunctionConfig,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
)
logger = logging.getLogger(__name__)
try:
from chromadb.api import ClientAPI
except ImportError as e:
raise ImportError(
"To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`"
) from e
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by ChromaDB.
`ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for
storing and retrieving content based on semantic similarity. It enhances agents with the ability
to recall contextually relevant information during conversations by leveraging vector embeddings
to find similar content.
This implementation serves as a reference for more complex memory systems using vector embeddings.
For advanced use cases requiring specialized formatting of retrieved content, users should extend
this class and override the `update_context()` method.
This implementation requires the ChromaDB extra to be installed. Install with:
.. code-block:: bash
pip install "agentdhal-ext[chromadb]"
Args:
config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory.
If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values.
Two config types are supported:
* PersistentChromaDBVectorMemoryConfig: For local storage
* HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server
Example:
.. code-block:: python
import os
import asyncio
from pathlib import Path
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core.memory import MemoryContent, MemoryMimeType
from agentdhal_extensions.memory.chromadb import (
ChromaDBVectorMemory,
PersistentChromaDBVectorMemoryConfig,
SentenceTransformerEmbeddingFunctionConfig,
OpenAIEmbeddingFunctionConfig,
)
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
def get_weather(city: str) -> str:
return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."
def fahrenheit_to_celsius(fahrenheit: float) -> float:
return (fahrenheit - 32) * 5.0 / 9.0
async def main() -> None:
# Use default embedding function
default_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="user_preferences",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
k=3, # Return top 3 results
score_threshold=0.5, # Minimum similarity score
)
)
# Using a custom SentenceTransformer model
custom_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="multilingual_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
model_name="paraphrase-multilingual-mpnet-base-v2"
),
)
)
# Using OpenAI embeddings
openai_memory = ChromaDBVectorMemory(
config=PersistentChromaDBVectorMemoryConfig(
collection_name="openai_memory",
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
embedding_function_config=OpenAIEmbeddingFunctionConfig(
api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
),
)
)
# Add user preferences to memory
await openai_memory.add(
MemoryContent(
content="The user prefers weather temperatures in Celsius",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "units"},
)
)
# Create assistant agent with ChromaDB memory
assistant = AssistantAgent(
name="assistant",
model_client=OpenAIChatCompletionClient(
model="gpt-4.1",
),
tools=[
get_weather,
fahrenheit_to_celsius,
],
max_tool_iterations=10,
memory=[openai_memory],
)
# The memory will automatically retrieve relevant content during conversations
await Console(assistant.run_stream(task="What's the temperature in New York?"))
# Remember to close the memory when finished
await default_memory.close()
await custom_memory.close()
await openai_memory.close()
asyncio.run(main())
Output:
.. code-block:: text
---------- TextMessage (user) ----------
What's the temperature in New York?
---------- MemoryQueryEvent (assistant) ----------
[MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)]
---------- ToolCallRequestEvent (assistant) ----------
[FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')]
---------- ToolCallExecutionEvent (assistant) ----------
[FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)]
---------- TextMessage (assistant) ----------
The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
"""
component_config_schema = ChromaDBVectorMemoryConfig
component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory"
def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
    """Store the configuration; no ChromaDB connection is opened here.

    Args:
        config: Memory configuration. A default
            ``PersistentChromaDBVectorMemoryConfig`` is used when omitted.
    """
    # Both handles start unset; _ensure_initialized() fills them in on demand.
    self._client: ClientAPI | None = None
    self._collection: Collection | None = None
    self._config = config or PersistentChromaDBVectorMemoryConfig()
@property
def collection_name(self) -> str:
    """Name of the ChromaDB collection, as given in the configuration."""
    name = self._config.collection_name
    return name
def _create_embedding_function(self) -> Any:
    """Build the embedding function selected by ``embedding_function_config``.

    Returns:
        A ChromaDB-compatible embedding function.

    Raises:
        ValueError: If the config type is unsupported, or if the user-supplied
            factory of a custom config raises.
        ImportError: If ``chromadb.utils.embedding_functions`` cannot be
            imported, or if constructing the SentenceTransformer or OpenAI
            embedding function fails for any reason.
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as exc:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from exc

    fn_config = self._config.embedding_function_config

    if isinstance(fn_config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(fn_config, SentenceTransformerEmbeddingFunctionConfig):
        model = fn_config.model_name
        try:
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
        except Exception as exc:
            # Every failure is reported as ImportError, not only missing packages.
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{model}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {exc}"
            ) from exc

    if isinstance(fn_config, OpenAIEmbeddingFunctionConfig):
        model = fn_config.model_name
        try:
            return embedding_functions.OpenAIEmbeddingFunction(api_key=fn_config.api_key, model_name=model)
        except Exception as exc:
            # Every failure is reported as ImportError, not only missing packages.
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{model}'. "
                f"Ensure openai is installed and API key is valid. Error: {exc}"
            ) from exc

    if isinstance(fn_config, CustomEmbeddingFunctionConfig):
        try:
            return fn_config.function(**fn_config.params)
        except Exception as exc:
            raise ValueError(f"Failed to create custom embedding function. Error: {exc}") from exc

    raise ValueError(f"Unsupported embedding function config type: {type(fn_config)}")
def _ensure_initialized(self) -> None:
    """Lazily create the ChromaDB client and collection on first use.

    The client is built from the configuration type (persistent or HTTP) and
    the collection is fetched or created with the configured embedding
    function. Calling this again once both exist is a no-op.

    Raises:
        ValueError: If the configuration is neither a persistent nor an HTTP
            config.
        Exception: Any error raised while creating the client or collection
            is logged and re-raised unchanged.
    """
    if self._client is None:
        try:
            from chromadb.config import Settings

            settings = Settings(allow_reset=self._config.allow_reset)
            if isinstance(self._config, PersistentChromaDBVectorMemoryConfig):
                self._client = PersistentClient(
                    path=self._config.persistence_path,
                    settings=settings,
                    tenant=self._config.tenant,
                    database=self._config.database,
                )
            elif isinstance(self._config, HttpChromaDBVectorMemoryConfig):
                self._client = HttpClient(
                    host=self._config.host,
                    port=self._config.port,
                    ssl=self._config.ssl,
                    headers=self._config.headers,
                    settings=settings,
                    tenant=self._config.tenant,
                    database=self._config.database,
                )
            else:
                raise ValueError(f"Unsupported config type: {type(self._config)}")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise

    if self._collection is None:
        try:
            embedding_function = self._create_embedding_function()
            # ChromaDB selects the index distance function from the
            # "hnsw:space" metadata key; "distance_metric" alone is just
            # free-form metadata and leaves the collection on its default
            # metric, which would not match ``_calculate_score``.
            # "distance_metric" is kept so existing readers of it still work.
            # NOTE(review): assumes the configured value is one Chroma
            # accepts ("cosine", "l2", "ip") -- confirm against the config class.
            self._collection = self._client.get_or_create_collection(
                name=self._config.collection_name,
                metadata={
                    "distance_metric": self._config.distance_metric,
                    "hnsw:space": self._config.distance_metric,
                },
                embedding_function=embedding_function,
            )
        except Exception as e:
            logger.error(f"Failed to get/create collection: {e}")
            raise
def _extract_text(self, content_item: str | MemoryContent) -> str:
    """Return the text that is stored/searched for a piece of content.

    Args:
        content_item: A plain string, or a ``MemoryContent`` whose payload is
            converted according to its MIME type.

    Returns:
        The string itself; ``str(content)`` for text and markdown; the
        lower-cased ``str()`` of the dict for JSON.

    Raises:
        ValueError: For JSON content that is not a dict, for image payloads,
            and for every other MIME type.
    """
    if isinstance(content_item, str):
        return content_item

    payload = content_item.content
    mime = content_item.mime_type

    if mime in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
        return str(payload)

    if mime == MemoryMimeType.JSON:
        if not isinstance(payload, dict):
            raise ValueError("JSON content must be a dict")
        # This is the lower-cased Python ``str()`` of the dict, not a JSON dump.
        # NOTE(review): casing is lost and the text is not valid JSON -- confirm this is intended.
        return str(payload).lower()

    if isinstance(payload, Image):
        raise ValueError("Image content cannot be converted to text")

    raise ValueError(f"Unsupported content type: {mime}")
def _calculate_score(self, distance: float) -> float:
    """Map a ChromaDB distance onto a similarity score (higher is closer).

    Args:
        distance: Distance reported by ChromaDB for one result.

    Returns:
        ``1 - distance / 2`` when the configured metric is ``"cosine"``,
        otherwise ``1 / (1 + distance)``.
    """
    uses_cosine = self._config.distance_metric == "cosine"
    return 1.0 - (distance / 2.0) if uses_cosine else 1.0 / (1.0 + distance)
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Inject memories relevant to the latest message into the model context.

    The last message of the context is used as the query. When the query
    returns results they are appended to the context as one numbered
    ``SystemMessage``; otherwise the context is left untouched.

    Args:
        model_context: The chat context to read from and extend.

    Returns:
        The query results that were (or would have been) added.
    """
    history = await model_context.get_messages()
    if not history:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    latest = history[-1]
    if isinstance(latest.content, str):
        query_text = latest.content
    else:
        # Non-string content: fall back to the string form of the whole message.
        query_text = str(latest)

    found = await self.query(query_text)
    if not found.results:
        return UpdateContextResult(memories=found)

    numbered = "\n".join(f"{idx}. {str(item.content)}" for idx, item in enumerate(found.results, 1))
    await model_context.add_message(SystemMessage(content="\nRelevant memory content:\n" + numbered))
    return UpdateContextResult(memories=found)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Store one memory item in the ChromaDB collection.

    The searchable text comes from ``_extract_text``; the item's metadata is
    stored alongside it with an extra ``mime_type`` entry, under a freshly
    generated UUID.

    Args:
        content: The memory item to store.
        cancellation_token: Accepted for interface compatibility; not used.

    Raises:
        RuntimeError: If the collection is still missing after initialization.
        Exception: Any error from text extraction or ChromaDB is logged and
            re-raised.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        text = self._extract_text(content)

        # Work on a copy: writing ``mime_type`` straight into
        # ``content.metadata`` would silently modify the caller's object.
        metadata_dict = dict(content.metadata or {})
        metadata_dict["mime_type"] = str(content.mime_type)

        self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())])
    except Exception as e:
        logger.error(f"Failed to add content to ChromaDB: {e}")
        raise
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Search the collection for items similar to ``query``.

    Args:
        query: Text, or a ``MemoryContent`` converted via ``_extract_text``.
        cancellation_token: Accepted for interface compatibility; not used.
        **kwargs: Forwarded unchanged to the collection's ``query`` call.

    Returns:
        The matching items. Each item's metadata carries the stored metadata
        plus ``score`` and ``id``; items scoring below the configured
        ``score_threshold`` (when set) are dropped.

    Raises:
        RuntimeError: If the collection is still missing after initialization.
        Exception: Any error from ChromaDB is logged and re-raised.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        query_text = self._extract_text(query)
        results = self._collection.query(
            query_texts=[query_text],
            n_results=self._config.k,
            include=["documents", "metadatas", "distances"],
            **kwargs,
        )

        memory_results: List[MemoryContent] = []
        has_payload = bool(results) and all(
            results.get(key) for key in ("documents", "metadatas", "distances")
        )
        if not has_payload:
            return MemoryQueryResult(results=memory_results)

        # A single query text was sent, so only the first result row is relevant.
        documents: List[Document] = results["documents"][0]
        metadatas: List[Metadata] = results["metadatas"][0]
        distances: List[float] = results["distances"][0]
        ids: List[str] = results["ids"][0] if results["ids"] else []

        threshold = self._config.score_threshold
        for doc, stored_metadata, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False):
            score = self._calculate_score(distance)
            metadata = {**stored_metadata, "score": score, "id": doc_id}
            if threshold is not None and score < threshold:
                continue

            # Items stored without a mime_type are treated as plain text.
            mime_type = str(stored_metadata.get("mime_type", MemoryMimeType.TEXT.value))
            memory_results.append(
                MemoryContent(
                    content=doc,
                    mime_type=mime_type,
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=memory_results)
    except Exception as e:
        logger.error(f"Failed to query ChromaDB: {e}")
        raise
async def clear(self) -> None:
    """Delete every record currently stored in the collection.

    Raises:
        RuntimeError: If the collection is still missing after initialization.
        Exception: Any error from ChromaDB is logged and re-raised.
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        existing = self._collection.get()
        stored_ids = existing["ids"] if existing else None
        # Nothing is deleted when the collection holds no ids.
        if stored_ids:
            self._collection.delete(ids=stored_ids)
    except Exception as exc:
        logger.error(f"Failed to clear ChromaDB collection: {exc}")
        raise
async def close(self) -> None:
    """Drop this object's references to the ChromaDB collection and client.

    Both are recreated by ``_ensure_initialized`` on the next use.
    """
    self._collection = self._client = None
async def reset(self) -> None:
    """Reset the ChromaDB client behind this memory, if the config allows it.

    After the call the cached collection handle is dropped, so the next
    operation fetches or creates the collection again.

    Raises:
        RuntimeError: If ``allow_reset`` is not enabled in the configuration.
    """
    # Check the permission first: a refused reset must not create a client
    # or a collection as a side effect.
    if not self._config.allow_reset:
        raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.")

    self._ensure_initialized()
    if self._client is not None:
        try:
            self._client.reset()
        except Exception as e:
            # Best effort: a failing reset is logged, not propagated.
            logger.error(f"Error during ChromaDB reset: {e}")
        finally:
            self._collection = None
def _to_config(self) -> ChromaDBVectorMemoryConfig:
    """Serialize the memory configuration.

    Returns:
        The configuration object this instance holds (the same object, not a copy).
    """
    return self._config
@classmethod
def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
    """Deserialize the memory configuration.

    Args:
        config: Configuration to build the memory from.

    Returns:
        A new instance constructed with ``config``.
    """
    return cls(config=config)

View File

@@ -0,0 +1,6 @@
from ._mem0 import Mem0Memory, Mem0MemoryConfig
__all__ = [
"Mem0Memory",
"Mem0MemoryConfig",
]

View File

@@ -0,0 +1,449 @@
import io
import logging
import uuid
from contextlib import redirect_stderr, redirect_stdout
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, cast
from agentdhal_core import CancellationToken, Component, ComponentBase
from agentdhal_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from mem0 import Memory as Memory0
from mem0 import MemoryClient
from pydantic import BaseModel, Field
from typing_extensions import Self
logger = logging.getLogger(__name__)
logging.getLogger("chromadb").setLevel(logging.ERROR)
class Mem0MemoryConfig(BaseModel):
    """Configuration for Mem0Memory component.

    The fields correspond one to one to the keyword arguments of
    ``Mem0Memory.__init__``.
    """

    # Owner of the stored memories; Mem0Memory substitutes a random UUID when this is None.
    user_id: Optional[str] = Field(
        default=None, description="User ID for memory operations. If not provided, a UUID will be generated."
    )
    # Default ``limit`` passed to the client's search call.
    limit: int = Field(default=10, description="Maximum number of results to return in memory queries.")
    # Selects between the hosted client (True) and the local client (False).
    is_cloud: bool = Field(default=True, description="Whether to use cloud Mem0 client (True) or local client (False).")
    # NOTE(review): the Mem0Memory docstring says the key is read from MEM0_API_KEY
    # when omitted, so "required" may be stricter than actual behavior -- confirm.
    api_key: Optional[str] = Field(
        default=None, description="API key for cloud Mem0 client. Required if is_cloud=True."
    )
    # Mem0Memory raises ValueError when is_cloud is False and this is None.
    config: Optional[Dict[str, Any]] = Field(
        default=None, description="Configuration dictionary for local Mem0 client. Required if is_cloud=False."
    )
class MemoryResult(TypedDict, total=False):
    """Keys that ``Mem0Memory.query`` reads from one search hit.

    ``total=False`` makes every key optional.
    """

    # Text of the memory; becomes the content of the returned item.
    memory: str
    # Copied into the returned metadata when present.
    score: float
    # Used as the starting metadata of the returned item when non-empty.
    metadata: Dict[str, Any]
    # Parsed with ``datetime.fromisoformat`` by the query code.
    created_at: str
    updated_at: str
    # Copied into the returned metadata when non-empty.
    categories: List[str]
# pyright: reportGeneralTypeIssues=false
class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]):
"""Mem0 memory implementation for AutoGen.
This component integrates with Mem0.ai's memory system, providing an implementation
of AutoGen's Memory interface. It supports both cloud and local backends through the
mem0ai Python package.
To use this component, you need to have the `mem0` (for cloud-only) or `mem0-local` (for local)
extra installed for the `agentdhal-ext` package:
.. code-block:: bash
pip install -U "agentdhal-ext[mem0]" # For cloud-based Mem0
pip install -U "agentdhal-ext[mem0-local]" # For local Mem0
The memory component can store and retrieve information that agents need to remember
across conversations. It also provides context updating for language models with
relevant memories.
Examples:
.. code-block:: python
import asyncio
from agentdhal_extensions.memory.mem0 import Mem0Memory
from agentdhal_core.memory import MemoryContent
async def main() -> None:
# Create a local Mem0Memory (no API key required)
memory = Mem0Memory(
is_cloud=False,
config={"path": ":memory:"}, # Use in-memory storage for testing
)
print("Memory initialized successfully!")
# Add something to memory
test_content = "User likes the color blue."
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
print(f"Added content: {test_content}")
# Retrieve memories with a search query
results = await memory.query("What color does the user like?")
print(f"Query results: {len(results.results)} found")
for i, result in enumerate(results.results):
print(f"Result {i+1}: {result}")
asyncio.run(main())
Output:
.. code-block:: text
Memory initialized successfully!
Added content: User likes the color blue.
Query results: 1 found
Result 1: content='User likes the color blue' mime_type='text/plain' metadata={'score': 0.6977155806281953, 'created_at': datetime.datetime(2025, 7, 6, 17, 25, 18, 754725, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200)))}
Using it with an :class:`~agentdhal_agentchat.agents.AssistantAgent`:
.. code-block:: python
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_core.memory import MemoryContent
from agentdhal_extensions.memory.mem0 import Mem0Memory
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
async def main() -> None:
# Create a model client
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
# Create a Mem0 memory instance
memory = Mem0Memory(
user_id="user123",
is_cloud=False,
config={"path": ":memory:"}, # Use in-memory storage for testing
)
# Add something to memory
test_content = "User likes the color blue."
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
# Create an assistant agent with Mem0 memory
agent = AssistantAgent(
name="assistant",
model_client=model_client,
memory=[memory],
system_message="You are a helpful assistant that remembers user preferences.",
)
# Run a sample task
result = await agent.run(task="What color does the user like?")
print(result.messages[-1].content) # type: ignore
asyncio.run(main())
Output:
.. code-block:: text
User likes the color blue.
Args:
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
limit: Maximum number of results to return in memory queries.
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided.
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
"""
component_type = "memory"
component_provider_override = "agentdhal_extensions.memory.mem0.Mem0Memory"
component_config_schema = Mem0MemoryConfig
def __init__(
    self,
    user_id: Optional[str] = None,
    limit: int = 10,
    is_cloud: bool = True,
    api_key: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> None:
    """Validate the arguments and create the Mem0 client.

    Args:
        user_id: Owner of the memories; a random UUID string is used when empty.
        limit: Default maximum number of search results.
        is_cloud: Use the hosted client (True) or the local client (False).
        api_key: API key handed to the hosted client.
        config: Configuration dictionary for the local client.

    Raises:
        ValueError: If ``is_cloud`` is False and ``config`` is None.
    """
    local_mode = not is_cloud
    if local_mode and config is None:
        raise ValueError("config is required when using local Mem0 client (is_cloud=False)")

    self._user_id = user_id or str(uuid.uuid4())
    self._limit = limit
    self._is_cloud = is_cloud
    self._api_key = api_key
    self._config = config

    if not self._is_cloud:
        # Guaranteed by the validation above; the assert narrows the type.
        assert self._config is not None
        self._client = Memory0.from_config(config_dict=self._config)  # type: ignore
        return

    self._client = MemoryClient(api_key=self._api_key)
@property
def user_id(self) -> str:
    """Return the user ID used for memory operations (given or auto-generated)."""
    return self._user_id
@property
def limit(self) -> int:
    """Return the default maximum number of results for memory queries."""
    return self._limit
@property
def is_cloud(self) -> bool:
    """Return True when the cloud Mem0 client was selected, False for the local one."""
    return self._is_cloud
@property
def config(self) -> Optional[Dict[str, Any]]:
    """Return the configuration dictionary given for the local client (may be None)."""
    return self._config
async def add(
    self,
    content: MemoryContent,
    cancellation_token: Optional[CancellationToken] = None,
) -> None:
    """Add content to memory.

    JSON content that is not already a string is serialized with
    ``json.dumps``; every other payload is stored as ``str(payload)``. A
    ``user_id`` entry in the content's metadata overrides this instance's
    user ID for the call and is not forwarded as metadata.

    Args:
        content: The memory content to add.
        cancellation_token: Optional token to cancel operation. When it
            reports cancellation nothing is sent to Mem0.

    Raises:
        Exception: If there's an error adding content to mem0 memory. The
            error is logged and re-raised.
    """
    if hasattr(content, "content") and hasattr(content, "mime_type"):
        # Enum members never compare equal to plain strings unless the enum
        # mixes in ``str``; compare on the member's value so both spellings work.
        mime = getattr(content.mime_type, "value", content.mime_type)
        payload = content.content
        if mime == "application/json" and not isinstance(payload, str):
            import json

            message = json.dumps(payload)
        else:
            message = str(payload)
        # Copy: ``user_id`` is popped below and must not disappear from the
        # caller's metadata dict.
        metadata = dict(content.metadata or {})
    else:
        # Content handed over as a bare value rather than a MemoryContent.
        message = str(content)
        metadata = {}

    # Same defensive check as ``query``: a token without a ``cancelled``
    # attribute is treated as not cancelled instead of raising.
    # NOTE(review): confirm ``cancelled`` is the attribute the token exposes.
    if cancellation_token is not None and getattr(cancellation_token, "cancelled", False):
        return

    try:
        user_id = metadata.pop("user_id", self._user_id)
        # A client whose class is named "Memory" is called without
        # ``output_format``; any other client gets output_format="v1.1".
        kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"}
        # Whatever the client prints to stdout/stderr during the call is discarded.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            self._client.add([{"role": "user", "content": message}], user_id=user_id, metadata=metadata, **kwargs)  # type: ignore
    except Exception as e:
        # Log for diagnosis, then let the caller see the failure.
        logger.error(f"Error adding to mem0 memory: {str(e)}")
        raise
async def query(
    self,
    query: str | MemoryContent = "",
    cancellation_token: Optional[CancellationToken] = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Query memory for relevant content.

    Args:
        query: The query to search for, either as string or MemoryContent.
        cancellation_token: Optional token to cancel operation.
        **kwargs: Additional query parameters to pass to mem0. ``limit``
            overrides the instance default for this call.

    Returns:
        MemoryQueryResult containing search results. An empty result is
        returned when the operation is cancelled or when the search fails
        (the failure is logged, not raised).
    """
    # A non-string query contributes its ``content`` attribute when it has one.
    query_text = query if isinstance(query, str) else str(getattr(query, "content", query))

    if cancellation_token and getattr(cancellation_token, "cancelled", False):
        return MemoryQueryResult(results=[])

    try:
        limit = kwargs.pop("limit", self._limit)
        # Whatever the client prints to stdout/stderr during the call is discarded.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            results = self._client.search(  # type: ignore
                query_text,
                user_id=self._user_id,
                limit=limit,
                **kwargs,
            )

        # The client answers either with {"results": [...]} or with a bare list.
        if isinstance(results, dict) and "results" in results:
            hits = cast(List[MemoryResult], results["results"])
        else:
            hits = cast(List[MemoryResult], results)

        memory_contents: List[MemoryContent] = []
        for hit in hits:
            metadata: Dict[str, Any] = hit.get("metadata") or {}

            if "score" in hit:
                metadata["score"] = hit["score"]

            # Timestamps that fail to parse are simply left out.
            for stamp in ("created_at", "updated_at"):
                raw_stamp = hit.get(stamp)
                if raw_stamp:
                    try:
                        metadata[stamp] = datetime.fromisoformat(raw_stamp)
                    except (ValueError, TypeError):
                        pass

            if hit.get("categories"):
                metadata["categories"] = hit["categories"]

            memory_contents.append(
                MemoryContent(
                    content=hit.get("memory", ""),
                    mime_type="text/plain",  # Default to text/plain
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=memory_contents)
    except Exception as exc:
        # Log the error but return empty results
        logger.error(f"Error querying mem0 memory: {str(exc)}")
        return MemoryQueryResult(results=[])
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Update the model context with relevant memories.

    The last message of the context serves as the query. Matching memories
    are appended to the context as a single numbered system message; when
    nothing matches, the context is left unchanged.

    Args:
        model_context: The model context to update.

    Returns:
        UpdateContextResult containing memories added to the context.
    """
    history = await model_context.get_messages()
    if not history:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    latest = history[-1]
    if isinstance(latest.content, str):
        query_text = latest.content
    else:
        # Non-string content: fall back to the string form of the whole message.
        query_text = str(latest)

    found = await self.query(query_text, limit=self._limit)
    if not found.results:
        return UpdateContextResult(memories=found)

    numbered = "\n".join(f"{idx}. {str(item.content)}" for idx, item in enumerate(found.results, 1))
    await model_context.add_message(SystemMessage(content="\nRelevant memories:\n" + numbered))
    return UpdateContextResult(memories=found)
async def clear(self) -> None:
    """Delete all memories stored for this instance's user ID.

    Raises:
        Exception: Any error raised by the Mem0 client; it is logged and
            re-raised.
    """
    try:
        self._client.delete_all(user_id=self._user_id)  # type: ignore
    except Exception as exc:
        failure = str(exc)
        logger.error(f"Error clearing mem0 memory: {failure}")
        raise
async def close(self) -> None:
    """Clean up resources if needed.

    This is a no-op for Mem0 clients as they don't require explicit cleanup;
    the method exists so callers can close this memory like any other.
    """
    pass
@classmethod
def _from_config(cls, config: Mem0MemoryConfig) -> Self:
    """Create instance from configuration.

    Args:
        config: Configuration for Mem0Memory component.

    Returns:
        A new Mem0Memory instance built from the configuration's fields.
    """
    init_kwargs = {
        "user_id": config.user_id,
        "limit": config.limit,
        "is_cloud": config.is_cloud,
        "api_key": config.api_key,
        "config": config.config,
    }
    return cls(**init_kwargs)
def _to_config(self) -> Mem0MemoryConfig:
    """Convert instance to configuration.

    Returns:
        Configuration representing this Mem0Memory instance. The stored API
        key is included as-is.
    """
    snapshot = {
        "user_id": self._user_id,
        "limit": self._limit,
        "is_cloud": self._is_cloud,
        "api_key": self._api_key,
        "config": self._config,
    }
    return Mem0MemoryConfig(**snapshot)

View File

@@ -0,0 +1,9 @@
from ._redis_memory import (
RedisMemory,
RedisMemoryConfig,
)
__all__ = [
"RedisMemoryConfig",
"RedisMemory",
]

View File

@@ -0,0 +1,325 @@
import logging
from typing import Any, List, Literal
from agentdhal_core import CancellationToken, Component
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import SystemMessage
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
try:
from redis import Redis
from redisvl.extensions.message_history import SemanticMessageHistory
from redisvl.utils.utils import deserialize, serialize
except ImportError as e:
raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e
class RedisMemoryConfig(BaseModel):
    """
    Configuration for Redis-based vector memory.

    This class defines the configuration options for using Redis as a vector memory store,
    supporting semantic memory. It allows customization of the Redis connection, index settings,
    similarity search parameters, and embedding model.
    """

    # Passed to ``Redis.from_url`` when the memory is constructed.
    redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
    # Used as the name of the message history backing the memory.
    index_name: str = Field(default="chat_history", description="Name of the Redis collection")
    # NOTE(review): ``prefix``, ``distance_metric``, ``algorithm``, ``datatype`` and
    # ``model_name`` are not read by the RedisMemory methods in this part of the
    # file -- confirm whether they are meant to reach the message history.
    prefix: str = Field(default="memory", description="prefix of the Redis collection")
    distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
    algorithm: Literal["flat", "hnsw"] = "flat"
    # Default ``top_k`` for queries; can be overridden per call.
    top_k: int = Field(default=10, description="Number of results to return in queries")
    datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
    # This value is handed to the search as a distance bound (smaller means
    # more similar), so it is an upper limit, not a minimum similarity score.
    distance_threshold: float = Field(
        default=0.7, description="Maximum vector distance for a memory to be considered relevant"
    )
    model_name: str | None = Field(
        default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
    )
class RedisMemory(Memory, Component[RedisMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by RedisVL.
`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
retrieving content based on semantic similarity. It enhances agents with the ability to recall
contextually relevant information during conversations by leveraging vector embeddings to find
similar content.
This implementation requires the RedisVL extra to be installed. Install with:
.. code-block:: bash
pip install "agentdhal-ext[redisvl]"
Additionally, you will need access to a Redis instance.
To run a local instance of redis in docker:
.. code-block:: bash
docker run -d --name redis -p 6379:6379 redis:8
To download and run Redis locally:
.. code-block:: bash
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update > /dev/null 2>&1
sudo apt-get install redis-server > /dev/null 2>&1
redis-server --daemonize yes
Args:
config (RedisMemoryConfig | None): Configuration for the Redis memory.
If None, defaults to a RedisMemoryConfig with recommended settings.
Example:
.. code-block:: python
from logging import WARNING, getLogger
import asyncio
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.ui import Console
from agentdhal_core.memory import MemoryContent, MemoryMimeType
from agentdhal_extensions.memory.redis import RedisMemory, RedisMemoryConfig
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
logger = getLogger()
logger.setLevel(WARNING)
# Define tool to use
async def get_weather(city: str, units: str = "imperial") -> str:
if units == "imperial":
return f"The weather in {city} is 73 °F and Sunny."
elif units == "metric":
return f"The weather in {city} is 23 °C and Sunny."
else:
return f"Sorry, I don't know the weather in {city}."
async def main():
# Initialize Redis memory
redis_memory = RedisMemory(
config=RedisMemoryConfig(
redis_url="redis://localhost:6379",
index_name="chat_history",
prefix="memory",
)
)
# Add user preferences to memory
await redis_memory.add(
MemoryContent(
content="The weather should be in metric units",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "units"},
)
)
await redis_memory.add(
MemoryContent(
content="Meal recipe must be vegan",
mime_type=MemoryMimeType.TEXT,
metadata={"category": "preferences", "type": "dietary"},
)
)
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
)
# Create assistant agent with Redis memory
assistant_agent = AssistantAgent(
name="assistant_agent",
model_client=model_client,
tools=[get_weather],
memory=[redis_memory],
)
stream = assistant_agent.run_stream(task="What is the weather in New York?")
await Console(stream)
await model_client.close()
await redis_memory.close()
asyncio.run(main())
Output:
.. code-block:: text
---------- TextMessage (user) ----------
What is the weather in New York?
---------- MemoryQueryEvent (assistant_agent) ----------
[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata={'category': 'preferences', 'type': 'units'})]
---------- ToolCallRequestEvent (assistant_agent) ----------
[FunctionCall(id='call_tyCPvPPAV4SHWhtfpM6UMemr', arguments='{"city":"New York","units":"metric"}', name='get_weather')]
---------- ToolCallExecutionEvent (assistant_agent) ----------
[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_tyCPvPPAV4SHWhtfpM6UMemr', is_error=False)]
---------- ToolCallSummaryMessage (assistant_agent) ----------
The weather in New York is 23 °C and Sunny.
"""
component_config_schema = RedisMemoryConfig
component_provider_override = "agentdhal_extensions.memory.redis_memory.RedisMemory"
def __init__(self, config: RedisMemoryConfig | None = None) -> None:
    """Initialize RedisMemory.

    Args:
        config: Memory configuration; a default ``RedisMemoryConfig`` is used
            when omitted.
    """
    self.config = config or RedisMemoryConfig()
    redis_client = Redis.from_url(url=self.config.redis_url)  # type: ignore[reportUknownMemberType]
    # The message history stores and searches the memories; it is named after the configured index.
    self.message_history = SemanticMessageHistory(redis_client=redis_client, name=self.config.index_name)
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """
    Update the model context with relevant memory content.

    The last message in the context is used as the query. All memories the
    query returns are joined and appended to the context as a single system
    message. When the context has no messages, or nothing relevant is found,
    the context is left unchanged.

    Args:
        model_context (ChatCompletionContext): The model context to update with relevant
            memories.

    Returns:
        UpdateContextResult: Object containing the memories that were used to update the
            context.
    """
    messages = await model_context.get_messages()
    if not messages:
        # Nothing to search with; skip the query instead of searching for an empty prompt.
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    query_results = await self.query(str(messages[-1].content))

    # Only extend the context when there is something to say: an empty
    # system message adds no information for the model.
    if query_results.results:
        stringified_messages = "\n\n".join(str(m.content) for m in query_results.results)
        await model_context.add_message(SystemMessage(content=stringified_messages))

    return UpdateContextResult(memories=query_results)
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Add a memory content object to Redis.

    .. note::

        To perform semantic search over stored memories RedisMemory creates a vector embedding
        from the content field of a MemoryContent object. This content is assumed to be text,
        JSON, or Markdown, and is passed to the vector embedding model specified in
        RedisMemoryConfig.

    Args:
        content (MemoryContent): The memory content to store within Redis. Its MIME type
            may be a ``MemoryMimeType`` member or the matching string value.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Raises:
        NotImplementedError: If the MIME type is not text, JSON, or Markdown.
    """
    unsupported = f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."

    # Accept the string spelling (e.g. "text/plain") as well as the enum
    # member; ``query`` already turns stored strings back into members this way.
    try:
        kind = MemoryMimeType(content.mime_type)
    except ValueError:
        raise NotImplementedError(unsupported) from None

    if kind == MemoryMimeType.TEXT:
        memory_content = content.content
        mime_type = "text/plain"
    elif kind == MemoryMimeType.JSON:
        memory_content = serialize(content.content)
        mime_type = "application/json"
    elif kind == MemoryMimeType.MARKDOWN:
        memory_content = content.content
        mime_type = "text/markdown"
    else:
        raise NotImplementedError(unsupported)

    # The internal ``mime_type`` tag is written last so user metadata cannot
    # replace it; ``query`` relies on it to decode the stored record.
    metadata = dict(content.metadata or {})
    metadata["mime_type"] = mime_type

    # The serialized metadata is carried in the message's ``tool_call_id`` field.
    self.message_history.add_message(
        {"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)}  # type: ignore[reportArgumentType]
    )
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Query memory content based on semantic vector similarity.

    .. note::

        RedisMemory.query() supports additional keyword arguments to improve query performance.

        top_k (int): The maximum number of relevant memories to include. Defaults to 10.
        distance_threshold (float): The maximum distance in vector space to consider a memory
            semantically similar when performing cosine similarity search. Defaults to 0.7.

    Args:
        query (str | MemoryContent): query to perform vector similarity search with. If a
            string is passed, a vector embedding is created from it with the model specified
            in the RedisMemoryConfig. If a MemoryContent object is passed, the content field
            of this object is extracted and a vector embedding is created from it with the
            model specified in the RedisMemoryConfig.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Returns:
        MemoryQueryResult: Object containing memories relevant to the provided query.

    Raises:
        NotImplementedError: If a MemoryContent query, or a stored memory, has a MIME type
            other than text, JSON, or Markdown.
        TypeError: If ``query`` is neither a string nor a MemoryContent.
    """
    # get the query string, or raise an error for unsupported MemoryContent types
    if isinstance(query, str):
        prompt = query
    elif isinstance(query, MemoryContent):
        unsupported = f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
        # Accept the string spelling (e.g. "text/plain") as well as the enum member.
        try:
            kind = MemoryMimeType(query.mime_type)
        except ValueError:
            raise NotImplementedError(unsupported) from None

        if kind in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
            prompt = str(query.content)
        elif kind == MemoryMimeType.JSON:
            prompt = serialize(query.content)
        else:
            raise NotImplementedError(unsupported)
    else:
        raise TypeError("'query' must be either a string or MemoryContent")

    # Per-call overrides win over the configured defaults.
    top_k = kwargs.pop("top_k", self.config.top_k)
    distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)

    results = self.message_history.get_relevant(
        prompt=prompt,  # type: ignore[reportArgumentType]
        top_k=top_k,
        distance_threshold=distance_threshold,
        raw=False,
    )

    memories: List[MemoryContent] = []
    for result in results:
        # ``add`` stores the serialized metadata (including the MIME type) in ``tool_call_id``.
        metadata = deserialize(result["tool_call_id"])  # type: ignore[reportArgumentType]
        mime_type = MemoryMimeType(metadata.pop("mime_type"))
        if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
            memory_content = result["content"]  # type: ignore[reportArgumentType]
        elif mime_type == MemoryMimeType.JSON:
            memory_content = deserialize(result["content"])  # type: ignore[reportArgumentType]
        else:
            raise NotImplementedError(
                f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
            )
        memory = MemoryContent(
            content=memory_content,  # type: ignore[reportArgumentType]
            mime_type=mime_type,
            metadata=metadata,
        )
        memories.append(memory)  # type: ignore[reportUknownMemberType]

    return MemoryQueryResult(results=memories)  # type: ignore[reportUknownMemberType]
async def clear(self) -> None:
    """Clear all entries from memory, preserving the RedisMemory resources.

    Delegates to the message history's ``clear``.
    """
    self.message_history.clear()
async def close(self) -> None:
    """Clears all entries from memory, and cleans up Redis client, index and resources.

    Unlike ``clear``, this calls the message history's ``delete``.
    """
    # NOTE(review): presumably ``delete`` drops the index along with its entries,
    # so stored memories do not survive ``close`` -- confirm against redisvl.
    self.message_history.delete()