first commit
This commit is contained in:
1
agent_dhal/agentdhal_extensions/memory/__init__.py
Normal file
1
agent_dhal/agentdhal_extensions/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._text_canvas import TextCanvas
|
||||
from ._text_canvas_memory import TextCanvasMemory
|
||||
|
||||
__all__ = ["TextCanvas", "TextCanvasMemory"]
|
||||
50
agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py
Normal file
50
agent_dhal/agentdhal_extensions/memory/canvas/_canvas.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
class BaseCanvas(ABC):
    """
    Abstract contract for "canvas" stores that keep a revision history
    of file-like data.

    Concrete subclasses decide what kind of payload is stored (text,
    images, structured data, ...) and how diffs/patches are represented.

    .. warning::

        This is an experimental API and may change in the future.

    """

    @abstractmethod
    def list_files(self) -> Dict[str, int]:
        """
        Return a mapping of filename -> latest revision number.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_latest_content(self, filename: str) -> Union[str, bytes, Any]:
        """
        Return the content stored in the newest revision of ``filename``.
        """
        raise NotImplementedError()

    @abstractmethod
    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """
        Store ``new_content`` as a new revision, creating the file if needed.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """
        Return a diff (format chosen by the subclass) between two revisions.
        """
        raise NotImplementedError()

    @abstractmethod
    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """
        Apply a patch/diff to the latest revision and record the result
        as the next revision.
        """
        raise NotImplementedError()
|
||||
@@ -0,0 +1,64 @@
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._text_canvas import TextCanvas
|
||||
|
||||
|
||||
class UpdateFileArgs(BaseModel):
    # Argument schema for the ``update_file`` tool (see UpdateFileTool.run).
    filename: str  # canvas file to create or overwrite
    new_content: str  # full text passed to the canvas as the file's new content
|
||||
|
||||
|
||||
class UpdateFileResult(BaseModel):
    # Result schema for the ``update_file`` tool.
    status: str  # UpdateFileTool.run always returns "OK" here
|
||||
|
||||
|
||||
class UpdateFileTool(BaseTool[UpdateFileArgs, UpdateFileResult]):
    """
    Tool that writes a complete file onto a canvas.

    The file is created when it does not exist yet; otherwise the supplied
    content is handed to the canvas as an update of the existing file.
    """

    def __init__(self, canvas: TextCanvas):
        super().__init__(
            name="update_file",
            description="Create/update a file on the canvas with the provided content.",
            args_type=UpdateFileArgs,
            return_type=UpdateFileResult,
        )
        # The canvas instance every invocation of this tool writes to.
        self._canvas = canvas

    async def run(self, args: UpdateFileArgs, cancellation_token: CancellationToken) -> UpdateFileResult:
        """Write ``args.new_content`` to ``args.filename`` and report success.

        ``cancellation_token`` is accepted for interface compatibility but is
        not consulted; the canvas write is a synchronous call.
        """
        target_canvas = self._canvas
        target_canvas.add_or_update_file(args.filename, args.new_content)
        outcome = UpdateFileResult(status="OK")
        return outcome
|
||||
|
||||
|
||||
class ApplyPatchArgs(BaseModel):
    # Argument schema for the ``apply_patch`` tool (see ApplyPatchTool.run).
    filename: str  # canvas file the patch is applied to
    patch_text: str  # patch body forwarded unchanged to the canvas' apply_patch
|
||||
|
||||
|
||||
class ApplyPatchResult(BaseModel):
    # Result schema for the ``apply_patch`` tool.
    status: str  # ApplyPatchTool.run always returns "PATCH APPLIED" here
|
||||
|
||||
|
||||
class ApplyPatchTool(BaseTool[ApplyPatchArgs, ApplyPatchResult]):
    """
    Tool that applies a unified diff patch to a file already on the canvas.
    """

    def __init__(self, canvas: TextCanvas):
        tool_description = (
            "Apply a unified diff patch to an existing file on the canvas. "
            "The patch must be in diff/patch format. The file must exist or be created first."
        )
        super().__init__(
            name="apply_patch",
            description=tool_description,
            args_type=ApplyPatchArgs,
            return_type=ApplyPatchResult,
        )
        # The canvas instance every invocation of this tool patches.
        self._canvas = canvas

    async def run(self, args: ApplyPatchArgs, cancellation_token: CancellationToken) -> ApplyPatchResult:
        """Apply ``args.patch_text`` to ``args.filename`` and report success.

        Any exception raised by the canvas while applying the patch propagates
        to the caller. ``cancellation_token`` is accepted for interface
        compatibility but is not consulted.
        """
        target_canvas = self._canvas
        target_canvas.apply_patch(args.filename, args.patch_text)
        outcome = ApplyPatchResult(status="PATCH APPLIED")
        return outcome
|
||||
192
agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py
Normal file
192
agent_dhal/agentdhal_extensions/memory/canvas/_text_canvas.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import difflib
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
try: # pragma: no cover
|
||||
from unidiff import PatchSet
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
PatchSet = None # type: ignore
|
||||
|
||||
from ._canvas import BaseCanvas
|
||||
|
||||
|
||||
class FileRevision:
    """A single stored revision of a file: its text and its revision id."""

    # Fixed attribute set; instances carry no per-instance ``__dict__``.
    __slots__ = ("content", "revision")

    def __init__(self, content: str, revision: int) -> None:
        # ``revision`` is an integer here, but conceptually it could be any
        # identifier such as a timestamp or a git hash.
        self.content, self.revision = content, revision
|
||||
|
||||
|
||||
class TextCanvas(BaseCanvas):
    """An in-memory canvas that stores *text* files with full revision history.

    .. warning::

        This is an experimental API and may change in the future.

    Besides the basic create/read/update operations, the canvas offers:

    * **apply_patch** - parses a unified diff with the ``unidiff`` library and
      applies its hunks, verifying that each hunk's context/removed lines
      match the current content before a new revision is committed.
    * **get_revision_content** - random access to any historical revision.
    * **get_revision_diffs** - the list of diffs between every consecutive
      pair of revisions, so a caller can replay or audit the change history.
    """

    # ----------------------------------------------------------------------------------
    # Construction helpers
    # ----------------------------------------------------------------------------------

    def __init__(self) -> None:
        # For each file we keep an *ordered* list of FileRevision where the last
        # element is the most recent.
        self._files: Dict[str, List[FileRevision]] = {}

    # ----------------------------------------------------------------------------------
    # Internal utilities
    # ----------------------------------------------------------------------------------

    def _latest_idx(self, filename: str) -> int:
        """Return the index (not revision number) of the newest revision.

        Returns ``-1`` when the file is unknown.
        """
        return len(self._files.get(filename, [])) - 1

    def _ensure_file(self, filename: str) -> None:
        """Raise ``ValueError`` unless *filename* already exists on the canvas."""
        if filename not in self._files:
            raise ValueError(f"File '{filename}' does not exist on the canvas; create it first.")

    def _find_revision(self, filename: str, revision: int) -> Union[FileRevision, None]:
        """Return the stored revision numbered *revision*, or ``None`` if absent."""
        for rev in self._files.get(filename, []):
            if rev.revision == revision:
                return rev
        return None

    @staticmethod
    def _unified_diff(filename: str, older: FileRevision, newer: FileRevision) -> str:
        """Render the unified diff that turns *older* into *newer*."""
        diff = difflib.unified_diff(
            older.content.splitlines(keepends=True),
            newer.content.splitlines(keepends=True),
            fromfile=f"{filename}@r{older.revision}",
            tofile=f"{filename}@r{newer.revision}",
        )
        return "".join(diff)

    @staticmethod
    def _lines_match(actual: List[str], expected: List[str]) -> bool:
        """Compare two line lists while ignoring differences in line endings.

        The final line of a file may lack a trailing newline while the same
        line in a patch carries one, so terminators are stripped first.
        """
        return [line.rstrip("\r\n") for line in actual] == [line.rstrip("\r\n") for line in expected]

    # ----------------------------------------------------------------------------------
    # Revision inspection helpers
    # ----------------------------------------------------------------------------------

    def get_revision_content(self, filename: str, revision: int) -> str:
        """Return the exact content stored in *revision*.

        If the revision does not exist an empty string is returned so that
        downstream code can handle the "not found" case without exceptions.
        Note that an existing revision may legitimately be empty as well.
        """
        rev = self._find_revision(filename, revision)
        return rev.content if rev is not None else ""

    def get_revision_diffs(self, filename: str) -> List[str]:
        """Return a *chronological* list of unified diffs for *filename*.

        Each element represents the diff that transformed revision *n* into
        revision *n+1* (starting at revision 1 -> 2). An unknown file or a
        file with a single revision yields an empty list.
        """
        revisions = self._files.get(filename, [])
        return [
            self._unified_diff(filename, older, newer)
            for older, newer in zip(revisions, revisions[1:])
        ]

    # ----------------------------------------------------------------------------------
    # BaseCanvas interface implementation
    # ----------------------------------------------------------------------------------

    def list_files(self) -> Dict[str, int]:
        """Return a mapping of *filename -> latest revision number*."""
        return {fname: revs[-1].revision for fname, revs in self._files.items() if revs}

    def get_latest_content(self, filename: str) -> str:  # noqa: D401 - keep API identical
        """Return the most recent content or an empty string if the file is new."""
        revs = self._files.get(filename, [])
        return revs[-1].content if revs else ""

    def add_or_update_file(self, filename: str, new_content: Union[str, bytes, Any]) -> None:
        """Create *filename* or append a new revision containing *new_content*.

        ``bytes`` input is decoded as UTF-8.

        Raises:
            ValueError: if *new_content* is neither ``str`` nor ``bytes``.
        """
        if isinstance(new_content, bytes):
            new_content = new_content.decode("utf-8")
        if not isinstance(new_content, str):
            raise ValueError(f"Expected str or bytes, got {type(new_content)}")
        if filename not in self._files:
            self._files[filename] = [FileRevision(new_content, 1)]
        else:
            last_rev_num = self._files[filename][-1].revision
            self._files[filename].append(FileRevision(new_content, last_rev_num + 1))

    def get_diff(self, filename: str, from_revision: int, to_revision: int) -> str:
        """Return a unified diff between *from_revision* and *to_revision*.

        An empty string is returned when the file is unknown or when either
        revision does not exist. Existence is checked on the revision itself
        rather than on its content, so a revision that is present but empty
        is still diffed correctly, and a missing revision is never silently
        treated as empty content.
        """
        older = self._find_revision(filename, from_revision)
        newer = self._find_revision(filename, to_revision)
        if older is None or newer is None:  # one (or both) revision ids not found
            return ""
        return self._unified_diff(filename, older, newer)

    def apply_patch(self, filename: str, patch_data: Union[str, bytes, Any]) -> None:
        """Apply *patch_data* (unified diff) to the latest revision and save a new revision.

        The patch is parsed with the *unidiff* library. Only the first file
        section of the patch is used. Every hunk is checked against the
        current content before it is applied; nothing is committed unless
        all hunks apply.

        Raises:
            ValueError: if *patch_data* is neither ``str`` nor ``bytes``, the
                file does not exist, the patch is empty, or a hunk does not
                match the current file content.
            ImportError: if the ``unidiff`` package is not installed.
        """
        if isinstance(patch_data, bytes):
            patch_data = patch_data.decode("utf-8")
        if not isinstance(patch_data, str):
            raise ValueError(f"Expected str or bytes, got {type(patch_data)}")
        self._ensure_file(filename)
        original_content = self.get_latest_content(filename)

        if PatchSet is None:
            raise ImportError(
                "The 'unidiff' package is required for patch application. Install with 'pip install unidiff'."
            )

        patch = PatchSet(patch_data)
        # Our canvas stores exactly one file per patch operation so we
        # use the first (and only) patched_file object.
        if not patch:
            raise ValueError("Empty patch text provided.")
        patched_file = patch[0]
        working_lines = original_content.splitlines(keepends=True)
        line_offset = 0
        for hunk in patched_file:
            # Calculate the slice boundaries in the *current* working copy.
            # In unified format a zero-length source range ("-5,0") names the
            # line *before* the insertion point, so no "- 1" adjustment
            # applies there; non-empty ranges are 1-based.
            if hunk.source_length == 0:
                start = hunk.source_start + line_offset
            else:
                start = hunk.source_start - 1 + line_offset
            end = start + hunk.source_length

            # Validate that the hunk really describes the text at this
            # position; otherwise unrelated lines would be overwritten.
            expected = [line.value for line in hunk if line.is_context or line.is_removed]
            position_ok = 0 <= start <= len(working_lines)
            if not position_ok or not self._lines_match(working_lines[start:end], expected):
                raise ValueError(
                    f"Patch hunk starting at source line {hunk.source_start} does not match "
                    f"the current content of '{filename}'."
                )

            # Build the replacement block for this hunk: context and added
            # lines are kept, removed lines are dropped.
            replacement: List[str] = [line.value for line in hunk if line.is_added or line.is_context]
            working_lines[start:end] = replacement
            line_offset += len(replacement) - (end - start)
        new_content = "".join(working_lines)

        # Finally commit the new revision.
        self.add_or_update_file(filename, new_content)

    # ----------------------------------------------------------------------------------
    # Convenience helpers
    # ----------------------------------------------------------------------------------

    def get_all_contents_for_context(self) -> str:  # noqa: D401 - keep public API stable
        """Return a summarised view of every file and its *latest* revision.

        The header and footer lines are always present, so the result is
        non-empty even when the canvas holds no files.
        """
        out: List[str] = ["=== CANVAS FILES ==="]
        for fname, revs in self._files.items():
            latest = revs[-1]
            out.append(f"File: {fname} (rev {latest.revision}):\n{latest.content}\n")
        out.append("=== END OF CANVAS ===")
        return "\n".join(out)
|
||||
@@ -0,0 +1,229 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.memory import (
|
||||
Memory,
|
||||
MemoryContent,
|
||||
MemoryMimeType,
|
||||
MemoryQueryResult,
|
||||
UpdateContextResult,
|
||||
)
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
|
||||
from ._canvas_writer import ApplyPatchTool, UpdateFileTool
|
||||
from ._text_canvas import TextCanvas
|
||||
|
||||
|
||||
class TextCanvasMemory(Memory):
    """
    A memory implementation backed by a Canvas that stores file-like content.
    The current state of the canvas is inserted into the ChatCompletionContext on each turn.

    .. warning::

        This is an experimental API and may change in the future.

    The TextCanvasMemory gives agents a persistent, file-like store they can read
    from and write to. Before each inference the current state of all files in
    the canvas is injected into the model context.

    This is particularly useful for:

    - Allowing agents to create and modify documents over multiple turns
    - Enabling collaborative document editing between multiple agents
    - Maintaining persistent state across conversation turns
    - Working with content too large to fit in a single message

    The canvas provides tools for:

    - Creating or updating files with new content
    - Applying patches (unified diff format) to existing files

    Examples:

        **Example: Using TextCanvasMemory with an AssistantAgent**

        The following example demonstrates how to create a TextCanvasMemory and use it with
        an AssistantAgent to write and update a story file.

        .. code-block:: python

            import asyncio
            from agentdhal_core import CancellationToken
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.messages import TextMessage
            from agentdhal_extensions.memory.canvas import TextCanvasMemory


            async def main():
                # Create a model client
                model_client = OpenAIChatCompletionClient(
                    model="gpt-4o",
                    # api_key = "your_openai_api_key"
                )

                # Create the canvas memory
                text_canvas_memory = TextCanvasMemory()

                # Get tools for working with the canvas
                update_file_tool = text_canvas_memory.get_update_file_tool()
                apply_patch_tool = text_canvas_memory.get_apply_patch_tool()

                # Create an agent with the canvas memory and tools
                writer_agent = AssistantAgent(
                    name="Writer",
                    model_client=model_client,
                    description="A writer agent that creates and updates stories.",
                    system_message='''
                    You are a Writer Agent. Your focus is to generate a story based on the user's request.

                    Instructions for using the canvas:

                    - The story should be stored on the canvas in a file named "story.md".
                    - If "story.md" does not exist, create it by calling the 'update_file' tool.
                    - If "story.md" already exists, generate a unified diff (patch) from the current
                      content to the new version, and call the 'apply_patch' tool to apply the changes.

                    IMPORTANT: Do not include the full story text in your chat messages.
                    Only write the story content to the canvas using the tools.
                    ''',
                    tools=[update_file_tool, apply_patch_tool],
                    memory=[text_canvas_memory],
                )

                # Send a message to the agent
                await writer_agent.on_messages(
                    [TextMessage(content="Write a short story about a bunny and a sunflower.", source="user")],
                    CancellationToken(),
                )

                # Retrieve the content from the canvas
                story_content = text_canvas_memory.canvas.get_latest_content("story.md")
                print("Story content from canvas:")
                print(story_content)


            if __name__ == "__main__":
                asyncio.run(main())

        **Example: Using TextCanvasMemory with multiple agents**

        The following example shows how to use TextCanvasMemory with multiple agents
        collaborating on the same document.

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_agentchat.agents import AssistantAgent
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_extensions.memory.canvas import TextCanvasMemory


            async def main():
                # Create a model client
                model_client = OpenAIChatCompletionClient(
                    model="gpt-4o",
                    # api_key = "your_openai_api_key"
                )

                # Create the shared canvas memory
                text_canvas_memory = TextCanvasMemory()
                update_file_tool = text_canvas_memory.get_update_file_tool()
                apply_patch_tool = text_canvas_memory.get_apply_patch_tool()

                # Create a writer agent
                writer_agent = AssistantAgent(
                    name="Writer",
                    model_client=model_client,
                    description="A writer agent that creates stories.",
                    system_message="You write children's stories on the canvas in story.md.",
                    tools=[update_file_tool, apply_patch_tool],
                    memory=[text_canvas_memory],
                )

                # Create a critique agent
                critique_agent = AssistantAgent(
                    name="Critique",
                    model_client=model_client,
                    description="A critique agent that provides feedback on stories.",
                    system_message="You review the story.md file and provide constructive feedback.",
                    memory=[text_canvas_memory],
                )

                # Create a team with both agents
                team = RoundRobinGroupChat(
                    participants=[writer_agent, critique_agent],
                    termination_condition=TextMentionTermination("TERMINATE"),
                    max_turns=10,
                )

                # Run the team on a task
                await team.run(task="Create a children's book about a bunny and a sunflower")

                # Get the final story
                story = text_canvas_memory.canvas.get_latest_content("story.md")
                print(story)


            if __name__ == "__main__":
                asyncio.run(main())
    """

    def __init__(self, canvas: Optional[TextCanvas] = None):
        super().__init__()
        # Use the caller's canvas when one is supplied; otherwise start empty.
        if canvas is None:
            canvas = TextCanvas()
        self.canvas = canvas

    async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
        """
        Add the canvas summary to ``model_context`` as a system message.

        The summary is also returned inside the result so callers can log or
        inspect what was injected. When the summary is blank nothing is added
        and an empty result is returned.
        """
        snapshot = self.canvas.get_all_contents_for_context()
        if not snapshot.strip():
            return UpdateContextResult(memories=MemoryQueryResult(results=[]))

        await model_context.add_message(SystemMessage(content=snapshot))

        # Return it for debugging/logging
        injected = MemoryContent(content=snapshot, mime_type=MemoryMimeType.TEXT)
        return UpdateContextResult(memories=MemoryQueryResult(results=[injected]))

    async def query(
        self, query: str | MemoryContent, cancellation_token: Optional[CancellationToken] = None, **kwargs: Any
    ) -> MemoryQueryResult:
        """
        Placeholder for searching filenames or file content.

        This implementation performs no search and always returns an empty result.
        """
        no_matches = MemoryQueryResult(results=[])
        return no_matches

    async def add(self, content: MemoryContent, cancellation_token: Optional[CancellationToken] = None) -> None:
        """
        Intentionally a no-op.

        Canvas changes are made through the tools returned by
        ``get_update_file_tool`` / ``get_apply_patch_tool`` rather than through
        this method.
        """
        return None

    async def clear(self) -> None:
        """Clear the entire canvas by replacing it with a new empty instance."""
        # NOTE(review): tools previously returned by get_update_file_tool() /
        # get_apply_patch_tool() were constructed with the old canvas object and
        # keep writing to it after this call - confirm that is intended.
        fresh_canvas = TextCanvas()
        self.canvas = fresh_canvas

    async def close(self) -> None:
        """Nothing to release; present to satisfy the Memory interface."""
        return None

    def get_update_file_tool(self) -> UpdateFileTool:
        """
        Build an UpdateFileTool bound to this memory's current canvas.
        """
        bound_canvas = self.canvas
        return UpdateFileTool(bound_canvas)

    def get_apply_patch_tool(self) -> ApplyPatchTool:
        """
        Build an ApplyPatchTool bound to this memory's current canvas.
        """
        bound_canvas = self.canvas
        return ApplyPatchTool(bound_canvas)
|
||||
21
agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py
Normal file
21
agent_dhal/agentdhal_extensions/memory/chromadb/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
from ._chromadb import ChromaDBVectorMemory
|
||||
|
||||
__all__ = [
|
||||
"ChromaDBVectorMemory",
|
||||
"ChromaDBVectorMemoryConfig",
|
||||
"PersistentChromaDBVectorMemoryConfig",
|
||||
"HttpChromaDBVectorMemoryConfig",
|
||||
"DefaultEmbeddingFunctionConfig",
|
||||
"SentenceTransformerEmbeddingFunctionConfig",
|
||||
"OpenAIEmbeddingFunctionConfig",
|
||||
"CustomEmbeddingFunctionConfig",
|
||||
]
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Configuration classes for ChromaDB vector memory."""
|
||||
|
||||
from typing import Any, Callable, Dict, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class DefaultEmbeddingFunctionConfig(BaseModel):
    """Configuration for the default ChromaDB embedding function.

    Uses ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.
    """

    # Discriminator value used by the EmbeddingFunctionConfig tagged union.
    function_type: Literal["default"] = "default"
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingFunctionConfig(BaseModel):
    """Configuration for SentenceTransformer embedding functions.

    Allows specifying a custom SentenceTransformer model for embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        model_name (str): Name of the SentenceTransformer model to use.
            Defaults to "all-MiniLM-L6-v2".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import SentenceTransformerEmbeddingFunctionConfig

            _ = SentenceTransformerEmbeddingFunctionConfig(model_name="paraphrase-multilingual-mpnet-base-v2")
    """

    # Discriminator value used by the EmbeddingFunctionConfig tagged union.
    function_type: Literal["sentence_transformer"] = "sentence_transformer"
    model_name: str = Field(default="all-MiniLM-L6-v2", description="SentenceTransformer model name to use")
|
||||
|
||||
|
||||
class OpenAIEmbeddingFunctionConfig(BaseModel):
    """Configuration for OpenAI embedding functions.

    Uses OpenAI's embedding API for generating embeddings.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    Args:
        api_key (str): OpenAI API key. If empty, will attempt to use environment variable.
        model_name (str): OpenAI embedding model name. Defaults to "text-embedding-ada-002".

    Example:
        .. code-block:: python

            from agentdhal_extensions.memory.chromadb import OpenAIEmbeddingFunctionConfig

            _ = OpenAIEmbeddingFunctionConfig(api_key="sk-...", model_name="text-embedding-3-small")
    """

    # Discriminator value used by the EmbeddingFunctionConfig tagged union.
    function_type: Literal["openai"] = "openai"
    # NOTE(review): stored as a plain ``str``, so the key is part of any dump of
    # this model - confirm that is acceptable wherever configs are serialized.
    api_key: str = Field(default="", description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="OpenAI embedding model name")
|
||||
|
||||
|
||||
class CustomEmbeddingFunctionConfig(BaseModel):
    """Configuration for custom embedding functions.

    Allows using a custom function that returns a ChromaDB-compatible embedding function.

    .. versionadded:: v0.4.1
       Support for custom embedding functions in ChromaDB memory.

    .. warning::
       Configurations containing custom functions are not serializable.

    Args:
        function (Callable): Function that returns a ChromaDB-compatible embedding function.
        params (Dict[str, Any]): Parameters to pass to the function.
    """

    # Discriminator value used by the EmbeddingFunctionConfig tagged union.
    function_type: Literal["custom"] = "custom"
    # Required field: there is no default, so a callable must be supplied.
    function: Callable[..., Any] = Field(description="Function that returns an embedding function")
    # default_factory gives every instance its own empty dict.
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters to pass to the function")
|
||||
|
||||
|
||||
# Tagged union type for embedding function configurations.
# The ``function_type`` literal declared on each member class is the
# discriminator, so a payload is validated against exactly the member whose
# ``function_type`` value it carries.
EmbeddingFunctionConfig = Annotated[
    Union[
        DefaultEmbeddingFunctionConfig,
        SentenceTransformerEmbeddingFunctionConfig,
        OpenAIEmbeddingFunctionConfig,
        CustomEmbeddingFunctionConfig,
    ],
    Field(discriminator="function_type"),
]
|
||||
|
||||
|
||||
class ChromaDBVectorMemoryConfig(BaseModel):
    """Base configuration for ChromaDB-based memory implementation.

    .. versionchanged:: v0.4.1
       Added support for custom embedding functions via embedding_function_config.
    """

    # No default here: the base class cannot be instantiated without a
    # client_type; the Persistent/Http subclasses supply one.
    client_type: Literal["persistent", "http"]
    collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection")
    # NOTE(review): free-form string with no validation in this model -
    # presumably must be a metric name ChromaDB accepts; confirm.
    distance_metric: str = Field(default="cosine", description="Distance metric for similarity search")
    # NOTE(review): no lower bound is enforced here; zero or negative values pass validation.
    k: int = Field(default=3, description="Number of results to return in queries")
    # None means no threshold is configured.
    score_threshold: float | None = Field(default=None, description="Minimum similarity score threshold")
    allow_reset: bool = Field(default=False, description="Whether to allow resetting the ChromaDB client")
    tenant: str = Field(default="default_tenant", description="Tenant to use")
    database: str = Field(default="default_database", description="Database to use")
    # Defaults to a fresh DefaultEmbeddingFunctionConfig per instance.
    embedding_function_config: EmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig, description="Configuration for the embedding function"
    )
|
||||
|
||||
|
||||
class PersistentChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for persistent ChromaDB memory."""

    # NOTE(review): the annotation still admits "http", so this class can be
    # built with client_type="http" - confirm whether it should be narrowed.
    client_type: Literal["persistent", "http"] = "persistent"
    # Relative default: resolved against the process working directory.
    persistence_path: str = Field(default="./chroma_db", description="Path for persistent storage")
|
||||
|
||||
|
||||
class HttpChromaDBVectorMemoryConfig(ChromaDBVectorMemoryConfig):
    """Configuration for HTTP ChromaDB memory."""

    # NOTE(review): the annotation still admits "persistent", so this class can
    # be built with client_type="persistent" - confirm whether it should be narrowed.
    client_type: Literal["persistent", "http"] = "http"
    host: str = Field(default="localhost", description="Host of the remote server")
    port: int = Field(default=8000, description="Port of the remote server")
    ssl: bool = Field(default=False, description="Whether to use HTTPS")
    # None means no extra headers are configured.
    headers: Dict[str, str] | None = Field(default=None, description="Headers to send to the server")
|
||||
459
agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py
Normal file
459
agent_dhal/agentdhal_extensions/memory/chromadb/_chromadb.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, Image
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import Document, Metadata
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._chroma_configs import (
|
||||
ChromaDBVectorMemoryConfig,
|
||||
CustomEmbeddingFunctionConfig,
|
||||
DefaultEmbeddingFunctionConfig,
|
||||
HttpChromaDBVectorMemoryConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Import ClientAPI, converting a failed import into an installation hint.
# NOTE(review): this module already imports from ``chromadb`` unguarded near
# the top of the file, so a missing chromadb package fails there first and
# this guard only covers ``chromadb.api.ClientAPI`` itself - confirm intent.
# NOTE(review): the message names ``autogen-ext[chromadb]`` while the
# ChromaDBVectorMemory docstring says ``agentdhal-ext[chromadb]`` - confirm
# which distribution name is correct.
try:
    from chromadb.api import ClientAPI
except ImportError as e:
    raise ImportError(
        "To use the ChromaDBVectorMemory the chromadb extra must be installed. Run `pip install autogen-ext[chromadb]`"
    ) from e
|
||||
|
||||
|
||||
class ChromaDBVectorMemory(Memory, Component[ChromaDBVectorMemoryConfig]):
|
||||
"""
|
||||
Store and retrieve memory using vector similarity search powered by ChromaDB.
|
||||
|
||||
`ChromaDBVectorMemory` provides a vector-based memory implementation that uses ChromaDB for
|
||||
storing and retrieving content based on semantic similarity. It enhances agents with the ability
|
||||
to recall contextually relevant information during conversations by leveraging vector embeddings
|
||||
to find similar content.
|
||||
|
||||
This implementation serves as a reference for more complex memory systems using vector embeddings.
|
||||
For advanced use cases requiring specialized formatting of retrieved content, users should extend
|
||||
this class and override the `update_context()` method.
|
||||
|
||||
This implementation requires the ChromaDB extra to be installed. Install with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[chromadb]"
|
||||
|
||||
Args:
|
||||
config (ChromaDBVectorMemoryConfig | None): Configuration for the ChromaDB memory.
|
||||
If None, defaults to a PersistentChromaDBVectorMemoryConfig with default values.
|
||||
Two config types are supported:
|
||||
* PersistentChromaDBVectorMemoryConfig: For local storage
|
||||
* HttpChromaDBVectorMemoryConfig: For connecting to a remote ChromaDB server
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core.memory import MemoryContent, MemoryMimeType
|
||||
from agentdhal_extensions.memory.chromadb import (
|
||||
ChromaDBVectorMemory,
|
||||
PersistentChromaDBVectorMemoryConfig,
|
||||
SentenceTransformerEmbeddingFunctionConfig,
|
||||
OpenAIEmbeddingFunctionConfig,
|
||||
)
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is sunny with a high of 90°F and a low of 70°F."
|
||||
|
||||
|
||||
def fahrenheit_to_celsius(fahrenheit: float) -> float:
|
||||
return (fahrenheit - 32) * 5.0 / 9.0
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Use default embedding function
|
||||
default_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="user_preferences",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
k=3, # Return top 3 results
|
||||
score_threshold=0.5, # Minimum similarity score
|
||||
)
|
||||
)
|
||||
|
||||
# Using a custom SentenceTransformer model
|
||||
custom_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="multilingual_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
|
||||
model_name="paraphrase-multilingual-mpnet-base-v2"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Using OpenAI embeddings
|
||||
openai_memory = ChromaDBVectorMemory(
|
||||
config=PersistentChromaDBVectorMemoryConfig(
|
||||
collection_name="openai_memory",
|
||||
persistence_path=os.path.join(str(Path.home()), ".chromadb_autogen"),
|
||||
embedding_function_config=OpenAIEmbeddingFunctionConfig(
|
||||
api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Add user preferences to memory
|
||||
await openai_memory.add(
|
||||
MemoryContent(
|
||||
content="The user prefers weather temperatures in Celsius",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "units"},
|
||||
)
|
||||
)
|
||||
|
||||
# Create assistant agent with ChromaDB memory
|
||||
assistant = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=OpenAIChatCompletionClient(
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tools=[
|
||||
get_weather,
|
||||
fahrenheit_to_celsius,
|
||||
],
|
||||
max_tool_iterations=10,
|
||||
memory=[openai_memory],
|
||||
)
|
||||
|
||||
# The memory will automatically retrieve relevant content during conversations
|
||||
await Console(assistant.run_stream(task="What's the temperature in New York?"))
|
||||
|
||||
# Remember to close the memory when finished
|
||||
await default_memory.close()
|
||||
await custom_memory.close()
|
||||
await openai_memory.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- TextMessage (user) ----------
|
||||
What's the temperature in New York?
|
||||
---------- MemoryQueryEvent (assistant) ----------
|
||||
[MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'type': 'units', 'category': 'preferences', 'mime_type': 'MemoryMimeType.TEXT', 'score': 0.3133561611175537, 'id': 'fb00506c-acf4-4174-93d7-2a942593f3f7'}), MemoryContent(content='The user prefers weather temperatures in Celsius', mime_type='MemoryMimeType.TEXT', metadata={'mime_type': 'MemoryMimeType.TEXT', 'category': 'preferences', 'type': 'units', 'score': 0.3133561611175537, 'id': '34311689-b419-4e1a-8bc4-09143f356c66'})]
|
||||
---------- ToolCallRequestEvent (assistant) ----------
|
||||
[FunctionCall(id='call_7TjsFd430J1aKwU5T2w8bvdh', arguments='{"city":"New York"}', name='get_weather')]
|
||||
---------- ToolCallExecutionEvent (assistant) ----------
|
||||
[FunctionExecutionResult(content='The weather in New York is sunny with a high of 90°F and a low of 70°F.', name='get_weather', call_id='call_7TjsFd430J1aKwU5T2w8bvdh', is_error=False)]
|
||||
---------- ToolCallRequestEvent (assistant) ----------
|
||||
[FunctionCall(id='call_RTjMHEZwDXtjurEYTjDlvq9c', arguments='{"fahrenheit": 90}', name='fahrenheit_to_celsius'), FunctionCall(id='call_3mMuCK1aqtzZPTqIHPoHKxtP', arguments='{"fahrenheit": 70}', name='fahrenheit_to_celsius')]
|
||||
---------- ToolCallExecutionEvent (assistant) ----------
|
||||
[FunctionExecutionResult(content='32.22222222222222', name='fahrenheit_to_celsius', call_id='call_RTjMHEZwDXtjurEYTjDlvq9c', is_error=False), FunctionExecutionResult(content='21.11111111111111', name='fahrenheit_to_celsius', call_id='call_3mMuCK1aqtzZPTqIHPoHKxtP', is_error=False)]
|
||||
---------- TextMessage (assistant) ----------
|
||||
The temperature in New York today is sunny with a high of about 32°C and a low of about 21°C.
|
||||
|
||||
"""
|
||||
|
||||
component_config_schema = ChromaDBVectorMemoryConfig
|
||||
component_provider_override = "agentdhal_extensions.memory.chromadb.ChromaDBVectorMemory"
|
||||
|
||||
def __init__(self, config: ChromaDBVectorMemoryConfig | None = None) -> None:
|
||||
self._config = config or PersistentChromaDBVectorMemoryConfig()
|
||||
self._client: ClientAPI | None = None
|
||||
self._collection: Collection | None = None
|
||||
|
||||
@property
|
||||
def collection_name(self) -> str:
|
||||
"""Get the name of the ChromaDB collection."""
|
||||
return self._config.collection_name
|
||||
|
||||
def _create_embedding_function(self) -> Any:
    """Build the embedding function selected by ``embedding_function_config``.

    The concrete config class decides which embedding function is built:
    the ChromaDB default, a SentenceTransformer model, an OpenAI model, or a
    user-supplied factory called with its configured parameters.

    Returns:
        A ChromaDB-compatible embedding function.

    Raises:
        ValueError: If the config type is not one of the supported classes,
            or if the custom factory raises.
        ImportError: If ``chromadb.utils.embedding_functions`` cannot be
            imported, or if building the SentenceTransformer / OpenAI
            embedding function fails for any reason.
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as exc:
        raise ImportError(
            "ChromaDB embedding functions not available. Ensure chromadb is properly installed."
        ) from exc

    ef_config = self._config.embedding_function_config

    if isinstance(ef_config, DefaultEmbeddingFunctionConfig):
        return embedding_functions.DefaultEmbeddingFunction()

    if isinstance(ef_config, SentenceTransformerEmbeddingFunctionConfig):
        try:
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=ef_config.model_name)
        except Exception as exc:
            # Any construction failure is reported as ImportError, not only
            # a missing sentence-transformers package.
            raise ImportError(
                f"Failed to create SentenceTransformer embedding function with model '{ef_config.model_name}'. "
                f"Ensure sentence-transformers is installed and the model is available. Error: {exc}"
            ) from exc

    if isinstance(ef_config, OpenAIEmbeddingFunctionConfig):
        try:
            return embedding_functions.OpenAIEmbeddingFunction(
                api_key=ef_config.api_key,
                model_name=ef_config.model_name,
            )
        except Exception as exc:
            raise ImportError(
                f"Failed to create OpenAI embedding function with model '{ef_config.model_name}'. "
                f"Ensure openai is installed and API key is valid. Error: {exc}"
            ) from exc

    if isinstance(ef_config, CustomEmbeddingFunctionConfig):
        try:
            return ef_config.function(**ef_config.params)
        except Exception as exc:
            raise ValueError(f"Failed to create custom embedding function. Error: {exc}") from exc

    raise ValueError(f"Unsupported embedding function config type: {type(ef_config)}")
|
||||
|
||||
def _ensure_initialized(self) -> None:
    """Create the ChromaDB client and collection on first use.

    Each handle is only built when it is currently ``None``, so repeated
    calls are cheap once both exist. Failures are logged and re-raised.

    Raises:
        ValueError: If the configuration is neither a persistent nor an
            HTTP ChromaDB configuration.
        Exception: Whatever the ChromaDB client or the embedding-function
            factory raises.
    """
    if self._client is None:
        try:
            from chromadb.config import Settings

            cfg = self._config
            client_settings = Settings(allow_reset=cfg.allow_reset)

            if isinstance(cfg, PersistentChromaDBVectorMemoryConfig):
                # Local, on-disk storage.
                self._client = PersistentClient(
                    path=cfg.persistence_path,
                    settings=client_settings,
                    tenant=cfg.tenant,
                    database=cfg.database,
                )
            elif isinstance(cfg, HttpChromaDBVectorMemoryConfig):
                # Remote ChromaDB server.
                self._client = HttpClient(
                    host=cfg.host,
                    port=cfg.port,
                    ssl=cfg.ssl,
                    headers=cfg.headers,
                    settings=client_settings,
                    tenant=cfg.tenant,
                    database=cfg.database,
                )
            else:
                raise ValueError(f"Unsupported config type: {type(cfg)}")
        except Exception as exc:
            logger.error(f"Failed to initialize ChromaDB client: {exc}")
            raise

    if self._collection is None:
        try:
            embedder = self._create_embedding_function()
            cfg = self._config

            # NOTE(review): ChromaDB documents "hnsw:space" as the metadata
            # key that selects the distance function; confirm that
            # "distance_metric" has the intended effect.
            self._collection = self._client.get_or_create_collection(
                name=cfg.collection_name,
                metadata={"distance_metric": cfg.distance_metric},
                embedding_function=embedder,
            )
        except Exception as exc:
            logger.error(f"Failed to get/create collection: {exc}")
            raise
|
||||
|
||||
def _extract_text(self, content_item: str | MemoryContent) -> str:
    """Turn a query string or a memory item into the text handed to ChromaDB.

    Args:
        content_item: A plain string (returned unchanged) or a
            ``MemoryContent`` whose MIME type decides the conversion.

    Returns:
        ``str(content)`` for text and markdown; for JSON, the lower-cased
        Python ``str()`` rendering of the dict (not a ``json.dumps`` string).

    Raises:
        ValueError: For JSON content that is not a dict, for image content,
            and for any other MIME type.
    """
    if isinstance(content_item, str):
        return content_item

    payload = content_item.content
    kind = content_item.mime_type

    if kind in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
        return str(payload)

    if kind == MemoryMimeType.JSON:
        if not isinstance(payload, dict):
            raise ValueError("JSON content must be a dict")
        return str(payload).lower()

    if isinstance(payload, Image):
        raise ValueError("Image content cannot be converted to text")

    raise ValueError(f"Unsupported content type: {kind}")
|
||||
|
||||
def _calculate_score(self, distance: float) -> float:
|
||||
"""Convert ChromaDB distance to a similarity score."""
|
||||
if self._config.distance_metric == "cosine":
|
||||
return 1.0 - (distance / 2.0)
|
||||
return 1.0 / (1.0 + distance)
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Query memory with the latest message and append the hits to the context.

    Args:
        model_context: Conversation context to read from and, when memories
            are found, to extend with one system message.

    Returns:
        The query results wrapped in an ``UpdateContextResult``; the result
        list is empty when the context holds no messages.
    """
    history = await model_context.get_messages()
    if not history:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    # The most recent message is the query; non-string content falls back
    # to the string form of the whole message object.
    latest = history[-1]
    if isinstance(latest.content, str):
        query_text = latest.content
    else:
        query_text = str(latest)

    found = await self.query(query_text)

    if found.results:
        # Present the memories as a 1-based numbered list.
        numbered = []
        for position, memory in enumerate(found.results, 1):
            numbered.append(f"{position}. {str(memory.content)}")
        memory_context = "\nRelevant memory content:\n" + "\n".join(numbered)

        await model_context.add_message(SystemMessage(content=memory_context))

    return UpdateContextResult(memories=found)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
|
||||
self._ensure_initialized()
|
||||
if self._collection is None:
|
||||
raise RuntimeError("Failed to initialize ChromaDB")
|
||||
|
||||
try:
|
||||
# Extract text from content
|
||||
text = self._extract_text(content)
|
||||
|
||||
# Use metadata directly from content
|
||||
metadata_dict = content.metadata or {}
|
||||
metadata_dict["mime_type"] = str(content.mime_type)
|
||||
|
||||
# Add to ChromaDB
|
||||
self._collection.add(documents=[text], metadatas=[metadata_dict], ids=[str(uuid.uuid4())])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add content to ChromaDB: {e}")
|
||||
raise
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Return the stored memories most similar to ``query``.

    Args:
        query: Search text, or a memory item whose text is extracted first.
        cancellation_token: Accepted for interface compatibility; this
            implementation does not consult it.
        **kwargs: Extra keyword arguments forwarded to the collection's
            ``query`` call.

    Returns:
        Up to ``k`` memories (``k`` from the configuration). Each result's
        metadata carries the stored metadata plus ``score`` and ``id``;
        results scoring below ``score_threshold`` are dropped when a
        threshold is configured.

    Raises:
        RuntimeError: If the collection is still missing after initialization.
        Exception: Any error raised while querying (logged, then re-raised).
    """
    self._ensure_initialized()
    if self._collection is None:
        raise RuntimeError("Failed to initialize ChromaDB")

    try:
        search_text = self._extract_text(query)

        raw = self._collection.query(
            query_texts=[search_text],
            n_results=self._config.k,
            include=["documents", "metadatas", "distances"],
            **kwargs,
        )

        matches: List[MemoryContent] = []

        # Nothing usable came back: bail out with the empty list.
        has_payload = bool(raw) and all(raw.get(field) for field in ("documents", "metadatas", "distances"))
        if not has_payload:
            return MemoryQueryResult(results=matches)

        # Only one query text was sent, so only the first inner list of
        # each field is read.
        documents: List[Document] = raw["documents"][0] if raw["documents"] else []
        metadatas: List[Metadata] = raw["metadatas"][0] if raw["metadatas"] else []
        distances: List[float] = raw["distances"][0] if raw["distances"] else []
        ids: List[str] = raw["ids"][0] if raw["ids"] else []

        threshold = self._config.score_threshold

        for doc, stored_meta, distance, doc_id in zip(documents, metadatas, distances, ids, strict=False):
            score = self._calculate_score(distance)

            # Copy the stored metadata before enriching it.
            enriched = dict(stored_meta)
            enriched["score"] = score
            enriched["id"] = doc_id

            if threshold is not None and score < threshold:
                continue

            # The MIME type is read back from metadata; text is the
            # fallback when none was stored.
            mime_type = str(stored_meta.get("mime_type", MemoryMimeType.TEXT.value))

            matches.append(
                MemoryContent(
                    content=doc,
                    mime_type=mime_type,
                    metadata=enriched,
                )
            )

        return MemoryQueryResult(results=matches)

    except Exception as exc:
        logger.error(f"Failed to query ChromaDB: {exc}")
        raise
|
||||
|
||||
async def clear(self) -> None:
|
||||
self._ensure_initialized()
|
||||
if self._collection is None:
|
||||
raise RuntimeError("Failed to initialize ChromaDB")
|
||||
|
||||
try:
|
||||
results = self._collection.get()
|
||||
if results and results["ids"]:
|
||||
self._collection.delete(ids=results["ids"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear ChromaDB collection: {e}")
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up ChromaDB client and resources."""
|
||||
self._collection = None
|
||||
self._client = None
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._ensure_initialized()
|
||||
if not self._config.allow_reset:
|
||||
raise RuntimeError("Reset not allowed. Set allow_reset=True in config to enable.")
|
||||
|
||||
if self._client is not None:
|
||||
try:
|
||||
self._client.reset()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ChromaDB reset: {e}")
|
||||
finally:
|
||||
self._collection = None
|
||||
|
||||
def _to_config(self) -> ChromaDBVectorMemoryConfig:
|
||||
"""Serialize the memory configuration."""
|
||||
|
||||
return self._config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ChromaDBVectorMemoryConfig) -> Self:
|
||||
"""Deserialize the memory configuration."""
|
||||
|
||||
return cls(config=config)
|
||||
6
agent_dhal/agentdhal_extensions/memory/mem0/__init__.py
Normal file
6
agent_dhal/agentdhal_extensions/memory/mem0/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from ._mem0 import Mem0Memory, Mem0MemoryConfig
|
||||
|
||||
__all__ = [
|
||||
"Mem0Memory",
|
||||
"Mem0MemoryConfig",
|
||||
]
|
||||
449
agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py
Normal file
449
agent_dhal/agentdhal_extensions/memory/mem0/_mem0.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, TypedDict, cast
|
||||
|
||||
from agentdhal_core import CancellationToken, Component, ComponentBase
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from mem0 import Memory as Memory0
|
||||
from mem0 import MemoryClient
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
class Mem0MemoryConfig(BaseModel):
    """Configuration for Mem0Memory component."""

    # Identity whose memories are read and written.
    user_id: Optional[str] = Field(
        description="User ID for memory operations. If not provided, a UUID will be generated.",
        default=None,
    )

    # Default cap on the number of search results.
    limit: int = Field(
        description="Maximum number of results to return in memory queries.",
        default=10,
    )

    # Backend selection: hosted client vs. local client.
    is_cloud: bool = Field(
        description="Whether to use cloud Mem0 client (True) or local client (False).",
        default=True,
    )

    # Credentials for the hosted backend.
    api_key: Optional[str] = Field(
        description="API key for cloud Mem0 client. Required if is_cloud=True.",
        default=None,
    )

    # Settings dictionary for the local backend.
    config: Optional[Dict[str, Any]] = Field(
        description="Configuration dictionary for local Mem0 client. Required if is_cloud=False.",
        default=None,
    )
|
||||
|
||||
|
||||
class MemoryResult(TypedDict, total=False):
    """Shape of one search hit as read by ``Mem0Memory.query``.

    Every key is optional (``total=False``).
    """

    # Text of the remembered item.
    memory: str
    # Relevance score reported for the hit.
    score: float
    # Metadata stored alongside the memory.
    metadata: Dict[str, Any]
    # Timestamps as strings; parsed with ``datetime.fromisoformat``.
    created_at: str
    updated_at: str
    # Category labels attached to the memory.
    categories: List[str]
|
||||
|
||||
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
class Mem0Memory(Memory, Component[Mem0MemoryConfig], ComponentBase[Mem0MemoryConfig]):
|
||||
"""Mem0 memory implementation for AutoGen.
|
||||
|
||||
This component integrates with Mem0.ai's memory system, providing an implementation
|
||||
of AutoGen's Memory interface. It supports both cloud and local backends through the
|
||||
mem0ai Python package.
|
||||
|
||||
To use this component, you need to have the `mem0` (for cloud-only) or `mem0-local` (for local)
|
||||
extra installed for the `autogen-ext` package:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "agentdhal-ext[mem0]" # For cloud-based Mem0
|
||||
pip install -U "agentdhal-ext[mem0-local]" # For local Mem0
|
||||
|
||||
The memory component can store and retrieve information that agents need to remember
|
||||
across conversations. It also provides context updating for language models with
|
||||
relevant memories.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_extensions.memory.mem0 import Mem0Memory
|
||||
from agentdhal_core.memory import MemoryContent
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a local Mem0Memory (no API key required)
|
||||
memory = Mem0Memory(
|
||||
is_cloud=False,
|
||||
config={"path": ":memory:"}, # Use in-memory storage for testing
|
||||
)
|
||||
print("Memory initialized successfully!")
|
||||
|
||||
# Add something to memory
|
||||
test_content = "User likes the color blue."
|
||||
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
|
||||
print(f"Added content: {test_content}")
|
||||
|
||||
# Retrieve memories with a search query
|
||||
results = await memory.query("What color does the user like?")
|
||||
print(f"Query results: {len(results.results)} found")
|
||||
|
||||
for i, result in enumerate(results.results):
|
||||
print(f"Result {i+1}: {result}")
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Memory initialized successfully!
|
||||
Added content: User likes the color blue.
|
||||
Query results: 1 found
|
||||
Result 1: content='User likes the color blue' mime_type='text/plain' metadata={'score': 0.6977155806281953, 'created_at': datetime.datetime(2025, 7, 6, 17, 25, 18, 754725, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200)))}
|
||||
|
||||
Using it with an :class:`~agentdhal_agentchat.agents.AssistantAgent`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.memory import MemoryContent
|
||||
from agentdhal_extensions.memory.mem0 import Mem0Memory
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4.1")
|
||||
|
||||
# Create a Mem0 memory instance
|
||||
memory = Mem0Memory(
|
||||
user_id="user123",
|
||||
is_cloud=False,
|
||||
config={"path": ":memory:"}, # Use in-memory storage for testing
|
||||
)
|
||||
|
||||
# Add something to memory
|
||||
test_content = "User likes the color blue."
|
||||
await memory.add(MemoryContent(content=test_content, mime_type="text/plain"))
|
||||
|
||||
# Create an assistant agent with Mem0 memory
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
memory=[memory],
|
||||
system_message="You are a helpful assistant that remembers user preferences.",
|
||||
)
|
||||
|
||||
# Run a sample task
|
||||
result = await agent.run(task="What color does the user like?")
|
||||
print(result.messages[-1].content) # type: ignore
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
User likes the color blue.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for memory operations. If not provided, a UUID will be generated.
|
||||
limit: Maximum number of results to return in memory queries.
|
||||
is_cloud: Whether to use cloud Mem0 client (True) or local client (False).
|
||||
api_key: API key for cloud Mem0 client. It will read from the environment MEM0_API_KEY if not provided.
|
||||
config: Configuration dictionary for local Mem0 client. Required if is_cloud=False.
|
||||
"""
|
||||
|
||||
component_type = "memory"
|
||||
component_provider_override = "agentdhal_extensions.memory.mem0.Mem0Memory"
|
||||
component_config_schema = Mem0MemoryConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
is_cloud: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
# Validate parameters
|
||||
if not is_cloud and config is None:
|
||||
raise ValueError("config is required when using local Mem0 client (is_cloud=False)")
|
||||
|
||||
# Initialize instance variables
|
||||
self._user_id = user_id or str(uuid.uuid4())
|
||||
self._limit = limit
|
||||
self._is_cloud = is_cloud
|
||||
self._api_key = api_key
|
||||
self._config = config
|
||||
|
||||
# Initialize client
|
||||
if self._is_cloud:
|
||||
self._client = MemoryClient(api_key=self._api_key)
|
||||
else:
|
||||
assert self._config is not None
|
||||
config_dict = self._config
|
||||
self._client = Memory0.from_config(config_dict=config_dict) # type: ignore
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
"""Get the user ID for memory operations."""
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def limit(self) -> int:
|
||||
"""Get the maximum number of results to return in memory queries."""
|
||||
return self._limit
|
||||
|
||||
@property
|
||||
def is_cloud(self) -> bool:
|
||||
"""Check if the Mem0 client is cloud-based."""
|
||||
return self._is_cloud
|
||||
|
||||
@property
|
||||
def config(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the configuration for the Mem0 client."""
|
||||
return self._config
|
||||
|
||||
async def add(
|
||||
self,
|
||||
content: MemoryContent,
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> None:
|
||||
"""Add content to memory.
|
||||
|
||||
Args:
|
||||
content: The memory content to add.
|
||||
cancellation_token: Optional token to cancel operation.
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error adding content to mem0 memory.
|
||||
"""
|
||||
# Extract content based on mime type
|
||||
if hasattr(content, "content") and hasattr(content, "mime_type"):
|
||||
if content.mime_type in ["text/plain", "text/markdown"]:
|
||||
message = str(content.content)
|
||||
elif content.mime_type == "application/json":
|
||||
# Convert JSON content to string representation
|
||||
if isinstance(content.content, str):
|
||||
message = content.content
|
||||
else:
|
||||
# Convert dict or other JSON serializable objects to string
|
||||
import json
|
||||
|
||||
message = json.dumps(content.content)
|
||||
else:
|
||||
message = str(content.content)
|
||||
|
||||
# Extract metadata
|
||||
metadata = content.metadata or {}
|
||||
else:
|
||||
# Handle case where content is directly provided as string
|
||||
message = str(content)
|
||||
metadata = {}
|
||||
|
||||
# Check if operation is cancelled
|
||||
if cancellation_token is not None and cancellation_token.cancelled: # type: ignore
|
||||
return
|
||||
|
||||
# Add to mem0 client
|
||||
try:
|
||||
user_id = metadata.pop("user_id", self._user_id)
|
||||
# Suppress warning messages from mem0 MemoryClient
|
||||
kwargs = {} if self._client.__class__.__name__ == "Memory" else {"output_format": "v1.1"}
|
||||
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
|
||||
self._client.add([{"role": "user", "content": message}], user_id=user_id, metadata=metadata, **kwargs) # type: ignore
|
||||
except Exception as e:
|
||||
# Log the error but don't crash
|
||||
logger.error(f"Error adding to mem0 memory: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent = "",
    cancellation_token: Optional[CancellationToken] = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Search Mem0 for memories relevant to ``query``.

    Args:
        query: Search text, or a memory item whose content is used as text.
        cancellation_token: Optional token; when it reports cancellation an
            empty result is returned without contacting Mem0.
        **kwargs: Extra search parameters forwarded to the Mem0 client. A
            ``limit`` entry overrides the instance's default limit.

    Returns:
        The matching memories as ``text/plain`` items whose metadata carries
        the stored metadata plus, when present, ``score``, ``created_at``,
        ``updated_at`` and ``categories``. Any error is logged and produces
        an empty result instead of an exception.
    """
    if isinstance(query, str):
        query_text = query
    else:
        query_text = str(query.content) if hasattr(query, "content") else str(query)

    if cancellation_token and getattr(cancellation_token, "cancelled", False):
        return MemoryQueryResult(results=[])

    try:
        max_results = kwargs.pop("limit", self._limit)
        # Suppress anything the mem0 client prints while searching.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            raw = self._client.search(  # type: ignore
                query_text,
                user_id=self._user_id,
                limit=max_results,
                **kwargs,
            )

        # The client may return either {"results": [...]} or a bare list.
        wrapped = isinstance(raw, dict) and "results" in raw
        hits = cast(List[MemoryResult], raw["results"] if wrapped else raw)

        found: List[MemoryContent] = []
        for hit in hits:
            text = hit.get("memory", "")

            # Use the stored metadata when non-empty, else a fresh dict.
            metadata: Dict[str, Any] = hit.get("metadata") or {}

            if "score" in hit:
                metadata["score"] = hit["score"]

            # Timestamps are exposed as datetime objects; values that do
            # not parse are left out of the metadata.
            for stamp in ("created_at", "updated_at"):
                if hit.get(stamp):
                    try:
                        metadata[stamp] = datetime.fromisoformat(hit[stamp])
                    except (ValueError, TypeError):
                        pass

            if hit.get("categories"):
                metadata["categories"] = hit["categories"]

            found.append(
                MemoryContent(
                    content=text,
                    mime_type="text/plain",  # Default to text/plain
                    metadata=metadata,
                )
            )

        return MemoryQueryResult(results=found)

    except Exception as exc:
        # Best effort: a failed search yields no memories.
        logger.error(f"Error querying mem0 memory: {str(exc)}")
        return MemoryQueryResult(results=[])
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """Inject memories relevant to the latest message into the model context.

    The last message of the conversation is used as the search query. Any
    memories found are appended to the context as one system message
    containing a numbered list.

    Args:
        model_context: The model context to read from and update.

    Returns:
        UpdateContextResult wrapping the query results (empty when the
        context holds no messages).
    """
    history = await model_context.get_messages()
    if not history:
        return UpdateContextResult(memories=MemoryQueryResult(results=[]))

    # Non-string content falls back to the string form of the message.
    latest = history[-1]
    if isinstance(latest.content, str):
        query_text = latest.content
    else:
        query_text = str(latest)

    found = await self.query(query_text, limit=self._limit)

    if found.results:
        numbered = []
        for position, memory in enumerate(found.results, 1):
            numbered.append(f"{position}. {str(memory.content)}")
        memory_context = "\nRelevant memories:\n" + "\n".join(numbered)

        await model_context.add_message(SystemMessage(content=memory_context))

    return UpdateContextResult(memories=found)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all content from memory for the current user.
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error clearing mem0 memory.
|
||||
"""
|
||||
try:
|
||||
self._client.delete_all(user_id=self._user_id) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing mem0 memory: {str(e)}")
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources if needed.
|
||||
|
||||
This is a no-op for Mem0 clients as they don't require explicit cleanup.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: Mem0MemoryConfig) -> Self:
|
||||
"""Create instance from configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration for Mem0Memory component.
|
||||
|
||||
Returns:
|
||||
A new Mem0Memory instance.
|
||||
"""
|
||||
return cls(
|
||||
user_id=config.user_id,
|
||||
limit=config.limit,
|
||||
is_cloud=config.is_cloud,
|
||||
api_key=config.api_key,
|
||||
config=config.config,
|
||||
)
|
||||
|
||||
def _to_config(self) -> Mem0MemoryConfig:
    """Capture this instance's settings as a configuration model.

    Returns:
        A ``Mem0MemoryConfig`` holding the user ID, limit, backend flag,
        API key and local configuration of this instance.
    """
    snapshot = {
        "user_id": self._user_id,
        "limit": self._limit,
        "is_cloud": self._is_cloud,
        "api_key": self._api_key,
        "config": self._config,
    }
    return Mem0MemoryConfig(**snapshot)
|
||||
9
agent_dhal/agentdhal_extensions/memory/redis/__init__.py
Normal file
9
agent_dhal/agentdhal_extensions/memory/redis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Public surface of the Redis memory package: re-export the memory class and its config.
from ._redis_memory import RedisMemory, RedisMemoryConfig

__all__ = ["RedisMemoryConfig", "RedisMemory"]
|
||||
325
agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py
Normal file
325
agent_dhal/agentdhal_extensions/memory/redis/_redis_memory.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import logging
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from agentdhal_core import CancellationToken, Component
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis support is an optional extra: fail at import time with an actionable
# message instead of a bare ModuleNotFoundError from deep inside the module.
try:
    from redis import Redis
    from redisvl.extensions.message_history import SemanticMessageHistory
    from redisvl.utils.utils import deserialize, serialize
except ImportError as e:
    # The install hint uses the same distribution name as the RedisMemory class
    # docstring; quoting keeps shells such as zsh from globbing the brackets.
    raise ImportError(
        'To use Redis Memory RedisVL must be installed. Run `pip install "agentdhal-ext[redisvl]"`'
    ) from e
|
||||
|
||||
|
||||
class RedisMemoryConfig(BaseModel):
    """
    Configuration for Redis-based vector memory.

    This class defines the configuration options for using Redis as a vector memory store,
    supporting semantic memory. It declares settings for the Redis connection, index,
    similarity search parameters, and embedding model.

    .. note::

        NOTE(review): within this module only ``redis_url`` and ``index_name`` are read when
        the memory is constructed, and ``top_k`` / ``distance_threshold`` are read as query
        defaults. Confirm where the remaining fields are consumed before relying on them.
    """

    # Connection string handed to ``Redis.from_url``.
    redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
    # Name given to the message-history index.
    index_name: str = Field(default="chat_history", description="Name of the Redis collection")
    prefix: str = Field(default="memory", description="prefix of the Redis collection")
    distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
    algorithm: Literal["flat", "hnsw"] = "flat"
    # Default cap on results; callers can override it per query.
    top_k: int = Field(default=10, description="Number of results to return in queries")
    datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
    # This is a *distance* ceiling (smaller means more similar), not a minimum similarity score.
    distance_threshold: float = Field(
        default=0.7,
        description="Maximum vector distance for a memory to be considered semantically similar",
    )
    model_name: str | None = Field(
        default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
    )
|
||||
|
||||
|
||||
class RedisMemory(Memory, Component[RedisMemoryConfig]):
|
||||
"""
|
||||
Store and retrieve memory using vector similarity search powered by RedisVL.
|
||||
|
||||
`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
|
||||
retrieving content based on semantic similarity. It enhances agents with the ability to recall
|
||||
contextually relevant information during conversations by leveraging vector embeddings to find
|
||||
similar content.
|
||||
|
||||
This implementation requires the RedisVL extra to be installed. Install with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install "agentdhal-ext[redisvl]"
|
||||
|
||||
Additionally, you will need access to a Redis instance.
|
||||
To run a local instance of redis in docker:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
docker run -d --name redis -p 6379:6379 redis:8
|
||||
|
||||
To download and run Redis locally:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
|
||||
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
|
||||
sudo apt-get update > /dev/null 2>&1
|
||||
sudo apt-get install redis-server > /dev/null 2>&1
|
||||
redis-server --daemonize yes
|
||||
|
||||
Args:
|
||||
config (RedisMemoryConfig | None): Configuration for the Redis memory.
|
||||
If None, defaults to a RedisMemoryConfig with recommended settings.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from logging import WARNING, getLogger
|
||||
|
||||
import asyncio
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.ui import Console
|
||||
from agentdhal_core.memory import MemoryContent, MemoryMimeType
|
||||
from agentdhal_extensions.memory.redis import RedisMemory, RedisMemoryConfig
|
||||
from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
|
||||
|
||||
logger = getLogger()
|
||||
logger.setLevel(WARNING)
|
||||
|
||||
|
||||
# Define tool to use
|
||||
async def get_weather(city: str, units: str = "imperial") -> str:
|
||||
if units == "imperial":
|
||||
return f"The weather in {city} is 73 °F and Sunny."
|
||||
elif units == "metric":
|
||||
return f"The weather in {city} is 23 °C and Sunny."
|
||||
else:
|
||||
return f"Sorry, I don't know the weather in {city}."
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize Redis memory
|
||||
redis_memory = RedisMemory(
|
||||
config=RedisMemoryConfig(
|
||||
redis_url="redis://localhost:6379",
|
||||
index_name="chat_history",
|
||||
prefix="memory",
|
||||
)
|
||||
)
|
||||
|
||||
# Add user preferences to memory
|
||||
await redis_memory.add(
|
||||
MemoryContent(
|
||||
content="The weather should be in metric units",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "units"},
|
||||
)
|
||||
)
|
||||
|
||||
await redis_memory.add(
|
||||
MemoryContent(
|
||||
content="Meal recipe must be vegan",
|
||||
mime_type=MemoryMimeType.TEXT,
|
||||
metadata={"category": "preferences", "type": "dietary"},
|
||||
)
|
||||
)
|
||||
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Create assistant agent with Redis memory
|
||||
assistant_agent = AssistantAgent(
|
||||
name="assistant_agent",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
memory=[redis_memory],
|
||||
)
|
||||
|
||||
stream = assistant_agent.run_stream(task="What is the weather in New York?")
|
||||
await Console(stream)
|
||||
|
||||
await model_client.close()
|
||||
await redis_memory.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
---------- TextMessage (user) ----------
|
||||
What is the weather in New York?
|
||||
---------- MemoryQueryEvent (assistant_agent) ----------
|
||||
[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata={'category': 'preferences', 'type': 'units'})]
|
||||
---------- ToolCallRequestEvent (assistant_agent) ----------
|
||||
[FunctionCall(id='call_tyCPvPPAV4SHWhtfpM6UMemr', arguments='{"city":"New York","units":"metric"}', name='get_weather')]
|
||||
---------- ToolCallExecutionEvent (assistant_agent) ----------
|
||||
[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', name='get_weather', call_id='call_tyCPvPPAV4SHWhtfpM6UMemr', is_error=False)]
|
||||
---------- ToolCallSummaryMessage (assistant_agent) ----------
|
||||
The weather in New York is 23 °C and Sunny.
|
||||
|
||||
"""
|
||||
|
||||
# Schema used to (de)serialize this component's configuration.
component_config_schema = RedisMemoryConfig
# Public import path of this class. It must name a module that really exists:
# RedisMemory is exported from the ``agentdhal_extensions.memory.redis`` package,
# and there is no ``redis_memory`` module under ``agentdhal_extensions.memory``.
component_provider_override = "agentdhal_extensions.memory.redis.RedisMemory"
|
||||
|
||||
def __init__(self, config: RedisMemoryConfig | None = None) -> None:
    """Create the Redis client and the semantic message history backing this memory.

    Args:
        config (RedisMemoryConfig | None): Settings for the memory. When omitted (or falsy),
            a default ``RedisMemoryConfig()`` is used.
    """
    self.config = config or RedisMemoryConfig()

    # NOTE(review): only ``redis_url`` and ``index_name`` are consumed here; the other
    # config fields are not passed to the message history by this constructor.
    redis_client = Redis.from_url(url=self.config.redis_url)  # type: ignore[reportUknownMemberType]
    self.message_history = SemanticMessageHistory(
        name=self.config.index_name,
        redis_client=redis_client,
    )
|
||||
|
||||
async def update_context(
    self,
    model_context: ChatCompletionContext,
) -> UpdateContextResult:
    """
    Update the model context with relevant memory content.

    The content of the last message in the context (or an empty string when the context
    has no messages) is used as the query for a semantic search. If the search returns
    any memories, their contents are joined with blank lines and appended to the context
    as a single system message. When nothing matches, the context is left untouched so
    the model is never handed an empty system message.

    Args:
        model_context (ChatCompletionContext): The model context to update with relevant
            memories.

    Returns:
        UpdateContextResult: Object containing the memories returned by the query
            (possibly none).
    """
    messages = await model_context.get_messages()
    last_message = str(messages[-1].content) if messages else ""

    query_results = await self.query(last_message)

    # Only inject a system message when there is something to say; an empty
    # SystemMessage adds noise to the prompt without carrying any memory.
    if query_results.results:
        stringified_messages = "\n\n".join(str(m.content) for m in query_results.results)
        await model_context.add_message(SystemMessage(content=stringified_messages))

    return UpdateContextResult(memories=query_results)
|
||||
|
||||
async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
    """Add a memory content object to Redis.

    .. note::

        To perform semantic search over stored memories RedisMemory creates a vector embedding
        from the content field of a MemoryContent object. This content is assumed to be text,
        JSON, or Markdown, and is passed to the vector embedding model specified in
        RedisMemoryConfig.

    The entry is written through ``message_history.add_message`` with role ``"user"``. The
    caller's metadata, plus an internal ``"mime_type"`` marker, is serialized into the
    message's ``tool_call_id`` field, which is where ``query()`` reads it back from.

    Args:
        content (MemoryContent): The memory content to store within Redis.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Raises:
        NotImplementedError: If ``content.mime_type`` is not TEXT, JSON, or MARKDOWN.
    """
    if content.mime_type == MemoryMimeType.TEXT:
        memory_content = content.content
        mime_type = "text/plain"
    elif content.mime_type == MemoryMimeType.JSON:
        # JSON payloads are stored in serialized form.
        memory_content = serialize(content.content)
        mime_type = "application/json"
    elif content.mime_type == MemoryMimeType.MARKDOWN:
        memory_content = content.content
        mime_type = "text/markdown"
    else:
        raise NotImplementedError(
            f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
        )

    # The internal marker is written last so a caller-supplied "mime_type" metadata key
    # cannot clobber it; query() relies on this value to decode the stored content.
    metadata = {**(content.metadata or {}), "mime_type": mime_type}

    self.message_history.add_message(
        {"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)}  # type: ignore[reportArgumentType]
    )
|
||||
|
||||
async def query(
    self,
    query: str | MemoryContent,
    cancellation_token: CancellationToken | None = None,
    **kwargs: Any,
) -> MemoryQueryResult:
    """Return stored memories that are semantically close to ``query``.

    .. note::

        Two optional keyword arguments are recognised:

        top_k (int): Maximum number of memories to return. Defaults to ``config.top_k``.
        distance_threshold (float): Maximum vector distance for a memory to count as
            relevant. Defaults to ``config.distance_threshold``.

    Args:
        query (str | MemoryContent): What to search for. A string is used as the prompt
            directly. For a MemoryContent, TEXT and MARKDOWN content is converted with
            ``str()`` and JSON content is serialized before being used as the prompt.
        cancellation_token (CancellationToken): Token passed to cease operation. Not used.

    Returns:
        MemoryQueryResult: Object containing the memories relevant to the provided query.

    Raises:
        TypeError: If ``query`` is neither a string nor a MemoryContent.
        NotImplementedError: If the query, or a stored entry, has a mime type other than
            TEXT, JSON, or MARKDOWN.
    """
    text_like = (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN)

    # Reduce the query to the prompt string handed to the message history.
    if isinstance(query, str):
        prompt = query
    elif not isinstance(query, MemoryContent):
        raise TypeError("'query' must be either a string or MemoryContent")
    elif query.mime_type in text_like:
        prompt = str(query.content)
    elif query.mime_type == MemoryMimeType.JSON:
        prompt = serialize(query.content)
    else:
        raise NotImplementedError(
            f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
        )

    # Per-call overrides win over the configured defaults.
    top_k = kwargs.pop("top_k", self.config.top_k)
    distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)

    hits = self.message_history.get_relevant(
        prompt=prompt,  # type: ignore[reportArgumentType]
        top_k=top_k,
        distance_threshold=distance_threshold,
        raw=False,
    )

    memories: List[MemoryContent] = []
    for hit in hits:
        # add() stores the metadata (including the mime type marker) in "tool_call_id".
        stored_metadata = deserialize(hit["tool_call_id"])  # type: ignore[reportArgumentType]
        stored_type = MemoryMimeType(stored_metadata.pop("mime_type"))

        if stored_type == MemoryMimeType.JSON:
            payload = deserialize(hit["content"])  # type: ignore[reportArgumentType]
        elif stored_type in text_like:
            payload = hit["content"]  # type: ignore[reportArgumentType]
        else:
            raise NotImplementedError(
                f"Error: {stored_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported."
            )

        memories.append(  # type: ignore[reportUknownMemberType]
            MemoryContent(
                content=payload,  # type: ignore[reportArgumentType]
                mime_type=stored_type,
                metadata=stored_metadata,
            )
        )

    return MemoryQueryResult(results=memories)  # type: ignore[reportUknownMemberType]
|
||||
|
||||
async def clear(self) -> None:
    """Remove all stored entries by delegating to ``message_history.clear()``.

    The RedisMemory object itself stays usable afterwards.
    """
    history = self.message_history
    history.clear()
|
||||
|
||||
async def close(self) -> None:
    """Tear down this memory by delegating to ``message_history.delete()``.

    This removes all stored entries together with the associated Redis index and resources.
    """
    history = self.message_history
    history.delete()
|
||||
Reference in New Issue
Block a user