first commit

This commit is contained in:
DigiJ
2026-03-13 12:56:43 -07:00
commit 159cf9fcfe
309 changed files with 64584 additions and 0 deletions

View File

@@ -0,0 +1,210 @@
# Task-Centric Memory
_(EXPERIMENTAL, RESEARCH IN PROGRESS)_
**Task-Centric Memory** is an active research project aimed at giving AI agents the ability to:
* Accomplish general tasks more effectively by learning quickly and continually beyond context-window limitations.
* Remember guidance, corrections, plans, and demonstrations provided by users.
* Learn through the agent's own experience and adapt quickly to changing circumstances.
* Avoid repeating mistakes on tasks that are similar to those previously encountered.
## Installation
Install AutoGen and its extension package as follows:
```bash
pip install -U "autogen-agentchat" "autogen-ext[openai]" "autogen-ext[task-centric-memory]"
```
## Quickstart
<p align="right">
<img src="../../../../imgs/task_centric_memory_2.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
</p>
This first code snippet runs a basic test to verify that the installation was successful,
as illustrated by the diagram to the right.
```python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.experimental.task_centric_memory import MemoryController
from autogen_ext.experimental.task_centric_memory.utils import PageLogger


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o")
    logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"})  # Optional, but very useful.
    memory_controller = MemoryController(reset=True, client=client, logger=logger)

    # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
    await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
    await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
    await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")

    # Retrieve memories for a new task that's related to only two of the stored memories.
    memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
    print("{} memories retrieved".format(len(memos)))
    for memo in memos:
        print("- " + memo.insight)


asyncio.run(main())
```
<p align="right">
<img src="../../../../imgs/task_centric_memory_3.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
</p>
This second code example shows one way to incorporate task-centric memory directly into an AutoGen agent,
in this case a subclass of RoutedAgent.
To keep the code short, only the simplest form of memory retrieval is exercised by this agent.
```python
import asyncio
from dataclasses import dataclass
from typing import List
from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.experimental.task_centric_memory import MemoryController
from autogen_ext.experimental.task_centric_memory.utils import PageLogger


@dataclass
class Message:
    content: str


class MemoryEnabledAgent(RoutedAgent):
    def __init__(
        self, description: str, model_client: ChatCompletionClient, memory_controller: MemoryController
    ) -> None:
        super().__init__(description)
        self._model_client = model_client
        self._memory_controller = memory_controller

    @message_handler
    async def handle_message(self, message: Message, context: MessageContext) -> Message:
        # Retrieve relevant memories for the task.
        memos = await self._memory_controller.retrieve_relevant_memos(task=message.content)

        # Format the memories for the model.
        formatted_memos = "Info that may be useful:\n" + "\n".join(["- " + memo.insight for memo in memos])
        print(f"{'-' * 23}Text appended to the user message{'-' * 24}\n{formatted_memos}\n{'-' * 80}")

        # Create the messages for the model with the retrieved memories.
        messages: List[LLMMessage] = [
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content=message.content, source="user"),
            UserMessage(content=formatted_memos, source="user"),
        ]

        # Call the model with the messages.
        model_result = await self._model_client.create(messages=messages)
        assert isinstance(model_result.content, str)

        # Send the model's response to the user.
        return Message(content=model_result.content)


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o")
    logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart2"})  # Optional, but very useful.
    memory_controller = MemoryController(reset=True, client=client, logger=logger)

    # Prepopulate memory to mimic learning from a prior session.
    await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
    await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
    await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")

    # Create and start an agent runtime.
    runtime = SingleThreadedAgentRuntime()
    runtime.start()

    # Register the agent type.
    await MemoryEnabledAgent.register(
        runtime,
        "memory_enabled_agent",
        lambda: MemoryEnabledAgent(
            "An agent with memory", model_client=client, memory_controller=memory_controller
        ),
    )

    # Send a direct message to the agent.
    request = "What colors do I like most?"
    print("User request: " + request)
    response = await runtime.send_message(
        Message(content=request), AgentId("memory_enabled_agent", "default")
    )
    print("Agent response: " + response.content)

    # Stop the agent runtime.
    await runtime.stop()


asyncio.run(main())
```
## Sample Code
The example above modifies the agent's code.
But it's also possible to add task-centric memory to an agent or multi-agent team _without_ modifying any agent code.
See the [sample code](../../../../../../samples/task_centric_memory) for that and other forms of fast, memory-based learning.
## Architecture
<p align="right">
<img src="../../../../imgs/task_centric_memory.png" alt="Description" width="300" align="right" style="margin-left: 10px;">
</p>
The block diagram to the right outlines the key components of the architecture in the most general form.
The memory components are shown in blue, and the green blocks represent external components.
The **Memory Controller** implements the fast-learning methods described below,
and manages communication with a **Memory Bank** containing a vector DB and associated structures.
The **Agent or Team** is the AI agent or team of agents to which memory is being added.
The sample code shows how to add task-centric memory to a simple AssistantAgent or a MagenticOneGroupChat team.
The **Apprentice, app, or service** represents the code that instantiates the agent and memory controller,
and routes information between them, effectively wrapping agent and memory into a combined component.
The term _Apprentice_ connotes that this combination uses memory to learn quickly on the job.
The Apprentice class is a minimal reference implementation provided as utility code for illustration and testing,
but most applications will use their own code instead of the Apprentice.
## Memory Creation and Storage
Each stored memory (called a _memo_) contains a text insight and (optionally) a task description.
The insight is intended to help the agent accomplish future tasks that are similar to a prior task.
The memory controller provides methods for different types of learning.
If the user provides advice for solving a given task, the advice is extracted by the model client and stored as an insight.
If the user demonstrates how to perform a task,
the task and demonstration are stored together as an insight used to solve similar but different tasks.
If the agent is given a task (free of side-effects) and some means of determining success or failure,
the memory controller repeats the following learning loop in the background some number of times:
1. Test the agent on the task a few times to check for a failure.
2. If a failure is found, analyze the agent's response in order to:
1. Diagnose the failure of reasoning or missing information,
2. Phrase a general piece of advice, such as what a teacher might give to a student,
3. Temporarily append this advice to the task description,
4. Return to step 1.
5. If some piece of advice succeeds in helping the agent solve the task a number of times, add the advice as an insight to memory.
3. For each insight to be stored in memory, an LLM is prompted to generate a set of free-form, multi-word topics related to the insight. Each topic is embedded to a fixed-length vector and stored in a vector DB mapping it to the topic's related insight.
## Memory Retrieval and Usage
The memory controller provides methods for different types of memory retrieval.
When the agent is given a task, the following steps are performed by the controller:
1. The task is rephrased into a generalized form.
2. A set of free-form, multi-word query topics is generated from the generalized task.
3. A potentially large number of previously stored topics, those most similar to each query topic, are retrieved from the vector DB along with the insights they map to.
4. These candidate memos are filtered by the aggregate similarity of their stored topics to the query topics.
5. In the final filtering stage, an LLM is prompted to validate only those insights that seem potentially useful in solving the task at hand.
Retrieved insights that pass the filtering steps are listed under a heading like
"Important insights that may help solve tasks like this", then appended to the task description before it is passed to the agent as usual.

View File

@@ -0,0 +1,4 @@
from ._memory_bank import MemoryBankConfig
from .memory_controller import MemoryController, MemoryControllerConfig
__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"]

View File

@@ -0,0 +1,201 @@
import os
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, TypedDict
from ._string_similarity_map import StringSimilarityMap
from .utils.page_logger import PageLogger
@dataclass
class Memo:
"""
Represents an atomic unit of memory that can be stored in a memory bank and later retrieved.
"""
task: str | None # The task description, if any.
insight: str # A hint, solution, plan, or any other text that may help solve a similar task.
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class MemoryBankConfig(TypedDict, total=False):
    """
    Optional overrides for the MemoryBank settings. Every key may be omitted (total=False).
    """

    path: str  # Directory where the memory bank files are stored.
    relevance_conversion_threshold: float  # Threshold used to normalize relevance.
    n_results: int  # Maximum number of most relevant results to return for any given topic.
    distance_threshold: int  # Maximum string-pair distance for a memo to be retrieved.
class MemoryBank:
    """
    Keeps task-completion insights (memos) in a vector DB so that they can be retrieved later.

    Args:
        reset: True to clear the DB before starting.
        config: An optional dict that can be used to override the following values:

            - path: The path to the directory where the memory bank files are stored.
            - relevance_conversion_threshold: The threshold used to normalize relevance.
            - n_results: The maximum number of most relevant results to return for any given topic.
            - distance_threshold: The maximum string-pair distance for a memo to be retrieved.

        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(
        self,
        reset: bool,
        config: MemoryBankConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        # Without a caller-supplied logger, use a default PageLogger, which logs nothing.
        self.logger = PageLogger() if logger is None else logger
        self.logger.enter_function()

        # Each setting is taken from the config when present, otherwise from its default.
        overrides: MemoryBankConfig = {} if config is None else config
        memory_dir_path = os.path.expanduser(overrides.get("path", "./memory_bank/default"))
        self.relevance_conversion_threshold = overrides.get("relevance_conversion_threshold", 1.7)
        self.n_results = overrides.get("n_results", 25)
        self.distance_threshold = overrides.get("distance_threshold", 100)

        self.logger.info("\nMEMORY BANK DIRECTORY {}".format(memory_dir_path))
        self.path_to_dict = os.path.join(memory_dir_path, "uid_memo_dict.pkl")
        self.string_map = StringSimilarityMap(
            reset=reset,
            path_to_db_dir=os.path.join(memory_dir_path, "string_map"),
            logger=self.logger,
        )

        # The memo dict lives on disk next to the string map; reload it unless a reset was requested.
        self.uid_memo_dict: Dict[str, Memo] = {}
        self.last_memo_id = 0
        if not reset and os.path.exists(self.path_to_dict):
            self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict))
            # NOTE(review): pickle.load can execute arbitrary code, so this file must come from a trusted source.
            with open(self.path_to_dict, "rb") as memo_file:
                self.uid_memo_dict = pickle.load(memo_file)
            # New memo ids continue counting from the number of memos already stored.
            self.last_memo_id = len(self.uid_memo_dict)
            self.logger.info("\n{} MEMOS LOADED".format(self.last_memo_id))

        # A requested reset also overwrites the memo dict on disk with an empty one.
        if reset:
            self._reset_memos()

        self.logger.leave_function()

    def reset(self) -> None:
        """
        Immediately deletes everything, both the string map and the memos, in memory and on disk.
        """
        self.string_map.reset_db()
        self._reset_memos()

    def _reset_memos(self) -> None:
        """
        Immediately deletes the memos, in memory and on disk.
        """
        self.logger.info("\nCLEARING MEMOS")
        self.uid_memo_dict = {}
        self.save_memos()

    def save_memos(self) -> None:
        """
        Writes the current memo structures (which may be empty) to disk.
        """
        # The string map persists its own dict first, then the memo dict is pickled.
        self.string_map.save_string_pairs()
        with open(self.path_to_dict, "wb") as memo_file:
            self.logger.info("\nSAVING MEMOS TO DISK at {}".format(self.path_to_dict))
            pickle.dump(self.uid_memo_dict, memo_file)

    def contains_memos(self) -> bool:
        """
        Returns True if at least one memo is stored in the memory bank.
        """
        return len(self.uid_memo_dict) > 0

    def _map_topics_to_memo(self, topics: List[str], memo_id: str, memo: Memo) -> None:
        """
        Maps every topic to the memo in the vec DB, registers the memo under its id, and saves to disk.
        """
        self.logger.enter_function()
        self.logger.info("\nINSIGHT\n{}".format(memo.insight))
        for topic in topics:
            self.logger.info("\n TOPIC = {}".format(topic))
            # The memo id is the output string, so a topic match leads back to the memo dict entry.
            self.string_map.add_input_output_pair(topic, memo_id)
        self.uid_memo_dict[memo_id] = memo
        self.save_memos()
        self.logger.leave_function()

    def add_memo(self, insight_str: str, topics: List[str], task_str: Optional[str] = None) -> None:
        """
        Stores an insight in the memory bank under the given related topics, optionally with its task.
        """
        self.logger.enter_function()
        self.last_memo_id += 1
        new_memo = Memo(insight=insight_str, task=task_str)
        self._map_topics_to_memo(topics, str(self.last_memo_id), new_memo)
        self.logger.leave_function()

    def add_task_with_solution(self, task: str, solution: str, topics: List[str]) -> None:
        """
        Stores a task-solution pair in the memory bank, to be retrieved together later as a combined insight.
        This is useful when the insight is a demonstration of how to solve a given type of task.
        """
        self.logger.enter_function()
        self.last_memo_id += 1
        # The stored insight is the task followed by its solution, so the pair is retrieved as one text.
        combined_insight = "Example task:\n\n{}\n\nExample solution:\n\n{}".format(task, solution)
        new_memo = Memo(insight=combined_insight, task=task)
        self._map_topics_to_memo(topics, str(self.last_memo_id), new_memo)
        self.logger.leave_function()

    def get_relevant_memos(self, topics: List[str]) -> List[Memo]:
        """
        Returns the memos that appear sufficiently relevant to the input topics, most relevant first.
        """
        self.logger.enter_function()

        # Pool the matches of every topic. Each match is a tuple: (topic, memo_id, distance)
        matches: List[Tuple[str, str, float]] = []
        for topic in topics:
            matches += self.string_map.get_related_string_pairs(topic, self.n_results, self.distance_threshold)

        # Accumulate one relevance score per memo; a memo matched by several topics sums its scores.
        relevance_by_memo: Dict[str, float] = {}
        for _topic, memo_id, distance in matches:
            relevance = self.relevance_conversion_threshold - distance
            if memo_id not in relevance_by_memo:
                relevance_by_memo[memo_id] = relevance
            else:
                relevance_by_memo[memo_id] += relevance

        # Log the details of every candidate memo.
        self.logger.info("\n{} POTENTIALLY RELEVANT MEMOS".format(len(relevance_by_memo)))
        for memo_id, relevance in relevance_by_memo.items():
            memo = self.uid_memo_dict[memo_id]
            task_part = "" if memo.task is None else "\n TASK: {}\n".format(memo.task)
            self.logger.info(
                task_part + "\n INSIGHT: {}\n\n RELEVANCE: {:.3f}\n".format(memo.insight, relevance)
            )

        # Rank by relevance (descending) and keep only the memos whose total relevance is non-negative.
        ranked = sorted(relevance_by_memo.items(), key=lambda item: item[1], reverse=True)
        memo_list: List[Memo] = [self.uid_memo_dict[memo_id] for memo_id, relevance in ranked if relevance >= 0]

        self.logger.leave_function()
        return memo_list

View File

@@ -0,0 +1,289 @@
import time
from typing import List, Union
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
SystemMessage,
UserMessage,
)
from .utils._functions import UserContent
from .utils.page_logger import PageLogger
class Prompter:
    """
    Centralizes most of the Apprentice prompts sent to the model client.

    Each public prompt method clears the chat history first, so the multi-turn
    prompts (learn_from_failure, generalize_task) only see their own earlier turns.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.client = client
        self.default_system_message_content = "You are a helpful assistant."

        # Running totals over all model calls made through this object.
        self.time_spent_in_model_calls = 0.0
        self.num_model_calls = 0
        self.start_time = time.time()

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: A short description of the call, passed to the logger.
            user_content: The content of the new user message.
            system_message_content: The system message; the default one is used if None.
            keep_these_messages: True to append the user message and the response to the chat history.

        Returns:
            The model's response as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = self.default_system_message_content
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)

        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Double check the types of the input messages.
        for message in input_messages:
            content = message.content
            # A plain string counts as a single part; iterating it would check every character.
            parts = [content] if isinstance(content, str) else content
            for part in parts:
                assert isinstance(part, (str, Image)), "Invalid message content type: {}".format(type(part))

        # Call the model
        start_time = time.time()
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")

        self.time_spent_in_model_calls += time.time() - start_time
        self.num_model_calls += 1

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string for now
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def learn_from_failure(
        self, task_description: str, memory_section: str, final_response: str, expected_answer: str, work_history: str
    ) -> str:
        """
        Tries to create an insight to help avoid the given failure in the future.

        Runs a three-turn conversation with the model: review the work, state the
        misconception, then condense the advice. The last response is returned.
        """
        sys_message = """- You are a patient and thorough teacher.
- Your job is to review work done by students and help them learn how to do better."""

        user_message: List[Union[str, Image]] = []
        user_message.append("# A team of students made a mistake on the following task:\n")
        user_message.extend([task_description])

        if len(memory_section) > 0:
            user_message.append(memory_section)

        user_message.append("# Here's the expected answer, which would have been correct:\n")
        user_message.append(expected_answer)

        user_message.append("# Here is the students' answer, which was INCORRECT:\n")
        user_message.append(final_response)

        user_message.append("# Please review the students' work which follows:\n")
        user_message.append("**----- START OF STUDENTS' WORK -----**\n\n")
        user_message.append(work_history)
        user_message.append("\n**----- END OF STUDENTS' WORK -----**\n\n")

        user_message.append(
            "# Now carefully review the students' work above, explaining in detail what the students did right and what they did wrong.\n"
        )

        # Start a fresh conversation; the following calls build on each other through the chat history.
        self._clear_history()
        await self.call_model(
            summary="Ask the model to learn from this failure",
            system_message_content=sys_message,
            user_content=user_message,
        )
        user_message = [
            "Now put yourself in the mind of the students. What misconception led them to their incorrect answer?"
        ]
        await self.call_model(
            summary="Ask the model to state the misconception",
            system_message_content=sys_message,
            user_content=user_message,
        )

        user_message = [
            "Please express your key insights in the form of short, general advice that will be given to the students. Just one or two sentences, or they won't bother to read it."
        ]
        insight = await self.call_model(
            summary="Ask the model to formulate a concise insight",
            system_message_content=sys_message,
            user_content=user_message,
        )
        return insight

    async def find_index_topics(self, input_string: str) -> List[str]:
        """
        Returns a list of topics related to the given string.

        The model is asked for one topic per line. Each line is stripped of surrounding
        whitespace, and lines that are empty after stripping are dropped.
        """
        sys_message = """You are an expert at semantic analysis."""

        user_message: List[Union[str, Image]] = []
        user_message.append("""- My job is to create a thorough index for a book called Task Completion, and I need your help.
- Every paragraph in the book needs to be indexed by all the topics related to various kinds of tasks and strategies for completing them.
- Your job is to read the text below and extract the task-completion topics that are covered.
- The number of topics depends on the length and content of the text. But you should list at least one topic, and potentially many more.
- Each topic you list should be a meaningful phrase composed of a few words. Don't use whole sentences as topics.
- Don't include details that are unrelated to the general nature of the task, or a potential strategy for completing tasks.
- List each topic on a separate line, without any extra text like numbering, or bullets, or any other formatting, because we don't want those things in the index of the book.\n\n""")

        user_message.append("# Text to be indexed\n")
        user_message.append(input_string)

        self._clear_history()
        topics = await self.call_model(
            summary="Ask the model to extract topics", system_message_content=sys_message, user_content=user_message
        )

        # Parse the topics into a list. Stripping removes stray "\r" and padding, and
        # keeps whitespace-only lines from being stored as topics.
        stripped_lines = (line.strip() for line in topics.split("\n"))
        topic_list: List[str] = [line for line in stripped_lines if line]
        return topic_list

    async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
        """
        Attempts to rewrite a task description in a more general form.

        Args:
            task_description: The task to generalize.
            revise: If truthy, two follow-up turns ask the model to drop irrelevant points and
                produce a final list, which is returned instead of the first response.
        """
        sys_message = """You are a helpful and thoughtful assistant."""

        # NOTE(review): "reach through" is presumably a typo for "read through"; the prompt is
        # left unchanged here so that model behavior stays the same.
        user_message: List[Union[str, Image]] = [
            "We have been given a task description. Our job is not to complete the task, but merely rephrase the task in simpler, more general terms, if possible. Please reach through the following task description, then explain your understanding of the task in detail, as a single, flat list of all the important points."
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)

        self._clear_history()
        generalized_task = await self.call_model(
            summary="Ask the model to rephrase the task in a list of important points",
            system_message_content=sys_message,
            user_content=user_message,
        )

        if revise:
            user_message = [
                "Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
            ]
            await self.call_model(
                summary="Ask the model to identify irrelevant points",
                system_message_content=sys_message,
                user_content=user_message,
            )

            user_message = [
                "Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
            ]
            generalized_task = await self.call_model(
                summary="Ask the model to make a final list of general terms",
                system_message_content=sys_message,
                user_content=user_message,
            )

        return generalized_task

    async def validate_insight(self, insight: str, task_description: str) -> bool:
        """
        Judges whether the insight could help solve the task.

        Returns True only if the model's reply, ignoring surrounding whitespace, is exactly "1".
        """
        sys_message = """You are a helpful and thoughtful assistant."""

        user_message: List[Union[str, Image]] = [
            """We have been given a potential insight that may or may not be useful for solving a given task.
- First review the following task.
- Then review the insight that follows, and consider whether it might help solve the given task.
- Do not attempt to actually solve the task.
- Reply with a single character, '1' if the insight may be useful, or '0' if it is not."""
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Possibly useful insight")
        user_message.append(insight)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to validate the insight",
            system_message_content=sys_message,
            user_content=user_message,
        )
        # Tolerate a trailing newline or padding around the single-character verdict.
        return response.strip() == "1"

    async def extract_task(self, text: str) -> str | None:
        """
        Returns a task found in the given text, or None if not found.

        The model signals "no task" by replying "None"; surrounding whitespace is ignored
        when checking for that sentinel. Any other reply is returned unmodified.
        """
        sys_message = """You are a helpful and thoughtful assistant."""
        # NOTE(review): "a question or a some task" looks like a typo; the prompt is left
        # unchanged here so that model behavior stays the same.
        user_message: List[Union[str, Image]] = [
            """Does the following text contain a question or a some task we are being asked to perform?
- If so, please reply with the full question or task description, along with any supporting information, but without adding extra commentary or formatting.
- If the task is just to remember something, that doesn't count as a task, so don't include it.
- If there is no question or task in the text, simply write "None" with no punctuation."""
        ]
        user_message.append("\n# Text to analyze")
        user_message.append(text)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to extract a task", system_message_content=sys_message, user_content=user_message
        )
        return None if response.strip() == "None" else response

    async def extract_advice(self, text: str) -> str | None:
        """
        Returns advice from the given text, or None if not found.

        The model signals "no advice" by replying "None"; surrounding whitespace is ignored
        when checking for that sentinel. Any other reply is returned unmodified.
        """
        sys_message = """You are a helpful and thoughtful assistant."""
        user_message: List[Union[str, Image]] = [
            """Does the following text contain any information or advice that might be useful later?
- If so, please copy the information or advice, adding no extra commentary or formatting.
- If there is no potentially useful information or advice at all, simply write "None" with no punctuation."""
        ]
        user_message.append("\n# Text to analyze")
        user_message.append(text)
        self._clear_history()
        response = await self.call_model(
            summary="Ask the model to extract advice", system_message_content=sys_message, user_content=user_message
        )
        return None if response.strip() == "None" else response

View File

@@ -0,0 +1,124 @@
import os
import pickle
from typing import Dict, List, Tuple, Union
import chromadb
from chromadb.api.types import (
QueryResult,
)
from chromadb.config import Settings
from .utils.page_logger import PageLogger
class StringSimilarityMap:
"""
Provides storage and similarity-based retrieval of string pairs using a vector database.
Each DB entry is a pair of strings: an input string and an output string.
The input string is embedded and used as the retrieval key.
The output string can be anything, but it's typically used as a dict key.
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
Args:
- reset: True to clear the DB immediately after creation.
- path_to_db_dir: Path to the directory where the DB is stored.
- logger: An optional logger. If None, no logging will be performed.
"""
def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None = None) -> None:
if logger is None:
logger = PageLogger() # Nothing will be logged by this object.
self.logger = logger
self.path_to_db_dir = path_to_db_dir
# Load or create the vector DB on disk.
chromadb_settings = Settings(
anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
)
self.db_client = chromadb.Client(chromadb_settings)
self.vec_db = self.db_client.create_collection("string-pairs", get_or_create=True) # The collection is the DB.
# Load or create the associated string-pair dict on disk.
self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
self.uid_text_dict: Dict[str, Tuple[str, str]] = {}
self.last_string_pair_id = 0
if (not reset) and os.path.exists(self.path_to_dict):
self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "rb") as f:
self.uid_text_dict = pickle.load(f)
self.last_string_pair_id = len(self.uid_text_dict)
if len(self.uid_text_dict) > 0:
self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict)))
self._log_string_pairs()
# Clear the DB if requested.
if reset:
self.reset_db()
def _log_string_pairs(self) -> None:
"""
Logs all string pairs currently in the map.
"""
self.logger.debug("LIST OF STRING PAIRS")
for uid, text in self.uid_text_dict.items():
input_text, output_text = text
self.logger.debug(" ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text))
def save_string_pairs(self) -> None:
"""
Saves the string-pair dict (self.uid_text_dict) to disk.
"""
self.logger.debug("\nSAVING STRING SIMILARITY MAP TO DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "wb") as file:
pickle.dump(self.uid_text_dict, file)
def reset_db(self) -> None:
"""
Forces immediate deletion of the DB's contents, in memory and on disk.
"""
self.logger.debug("\nCLEARING STRING-PAIR MAP")
self.db_client.delete_collection("string-pairs")
self.vec_db = self.db_client.create_collection("string-pairs")
self.uid_text_dict = {}
self.save_string_pairs()
def add_input_output_pair(self, input_text: str, output_text: str) -> None:
"""
Adds one input-output string pair to the DB.
"""
self.last_string_pair_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_string_pair_id)])
self.uid_text_dict[str(self.last_string_pair_id)] = input_text, output_text
self.logger.debug(
"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format(
self.last_string_pair_id, input_text, output_text
)
)
# self._log_string_pairs() # For deeper debugging, uncomment to log all string pairs after each addition.
def get_related_string_pairs(
    self, query_text: str, n_results: int, threshold: Union[int, float]
) -> List[Tuple[str, str, float]]:
    """
    Looks up the stored string pairs whose input text is closest to the query text.

    At most n_results pairs are requested from the vector DB (never more than are stored), and
    only those whose distance is strictly below the threshold are returned.

    Returns:
        A list of (input_text, output_text, distance) tuples, in the order given by the vector DB.
    """
    matches: List[Tuple[str, str, float]] = []
    limit = min(n_results, len(self.uid_text_dict))
    if limit <= 0:
        return matches

    results: QueryResult = self.vec_db.query(query_texts=[query_text], n_results=limit)
    for position, uid in enumerate(results["ids"][0]):
        # Documents and distances may be absent from the result; fall back to neutral values.
        input_text = results["documents"][0][position] if results["documents"] else ""
        distance = results["distances"][0][position] if results["distances"] else 0.0
        if distance < threshold:
            stored_input, output_text = self.uid_text_dict[uid]
            # The vector DB and the dict must agree on the text stored under this ID.
            assert input_text == stored_input
            self.logger.debug(
                "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
                    input_text, output_text, distance
                )
            )
            matches.append((input_text, output_text, distance))
    return matches

View File

@@ -0,0 +1,478 @@
from typing import TYPE_CHECKING, Awaitable, Callable, List, Tuple, TypedDict
from agentdhal_core.models import (
ChatCompletionClient,
)
from ._memory_bank import Memo, MemoryBank
from ._prompter import Prompter
if TYPE_CHECKING:
from ._memory_bank import MemoryBankConfig
from .utils.grader import Grader
from .utils.page_logger import PageLogger
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class MemoryControllerConfig(TypedDict, total=False):
    """
    Optional overrides for MemoryController settings. Every key may be omitted (total=False).
    """

    generalize_task: bool  # Whether to rewrite tasks in more general terms.
    revise_generalized_task: bool  # Whether to critique then rewrite the generalized task.
    generate_topics: bool  # Whether to base retrieval on topics extracted from tasks, or directly on tasks.
    validate_memos: bool  # Whether to apply a final validation stage to retrieved memos.
    max_memos_to_retrieve: int  # The maximum number of memos returned by retrieve_relevant_memos().
    max_train_trials: int  # The maximum number of learning iterations when training on a task.
    max_test_trials: int  # The total number of attempts made when testing for failure on a task.
    MemoryBank: "MemoryBankConfig"  # A config dict passed through to MemoryBank.
class MemoryController:
    """
    (EXPERIMENTAL, RESEARCH IN PROGRESS)

    Implements fast, memory-based learning, and manages the flow of information to and from a memory bank.

    Args:
        reset: True to empty the memory bank before starting.
        client: The model client to use internally.
        task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
        config: An optional dict that can be used to override the following values:

            - generalize_task: Whether to rewrite tasks in more general terms.
            - revise_generalized_task: Whether to critique then rewrite the generalized task.
            - generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
            - validate_memos: Whether to apply a final validation stage to retrieved memos.
            - max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
            - max_train_trials: The maximum number of learning iterations to attempt when training on a task.
            - max_test_trials: The total number of attempts made when testing for failure on a task.
            - MemoryBank: A config dict passed to MemoryBank.

        logger: An optional logger. If None, a default logger will be created.

    Example:

        The `task-centric-memory` extra first needs to be installed:

        .. code-block:: bash

            pip install "agentdhal-ext[task-centric-memory]"

        The following code snippet shows how to use this class for the most basic storage and retrieval of memories:

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.experimental.task_centric_memory import MemoryController
            from agentdhal_extensions.experimental.task_centric_memory.utils import PageLogger


            async def main() -> None:
                client = OpenAIChatCompletionClient(model="gpt-4o")
                logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"})  # Optional, but very useful.
                memory_controller = MemoryController(reset=True, client=client, logger=logger)

                # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
                await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
                await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
                await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")

                # Retrieve memories for a new task that's related to only two of the stored memories.
                memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
                print("{} memories retrieved".format(len(memos)))
                for memo in memos:
                    print("- " + memo.insight)


            asyncio.run(main())
    """
def __init__(
    self,
    reset: bool,
    client: ChatCompletionClient,
    task_assignment_callback: Callable[[str], Awaitable[Tuple[str, str]]] | None = None,
    config: MemoryControllerConfig | None = None,
    logger: PageLogger | None = None,
) -> None:
    """
    Applies the default settings plus any config overrides, then builds the prompter,
    the memory bank, and the grader around the given client and logger.
    """
    # Fall back to a DEBUG-level PageLogger when the caller supplies none.
    self.logger = logger if logger is not None else PageLogger({"level": "DEBUG"})
    self.logger.enter_function()

    # Each setting takes its value from the config when present, else the default given here.
    overrides: MemoryControllerConfig = config if config is not None else {}
    self.generalize_task = overrides.get("generalize_task", True)
    self.revise_generalized_task = overrides.get("revise_generalized_task", True)
    self.generate_topics = overrides.get("generate_topics", True)
    self.validate_memos = overrides.get("validate_memos", True)
    self.max_memos_to_retrieve = overrides.get("max_memos_to_retrieve", 10)
    self.max_train_trials = overrides.get("max_train_trials", 10)
    self.max_test_trials = overrides.get("max_test_trials", 3)
    memory_bank_config = overrides.get("MemoryBank", None)

    self.client = client
    self.task_assignment_callback = task_assignment_callback
    self.prompter = Prompter(client, self.logger)
    self.memory_bank = MemoryBank(reset=reset, config=memory_bank_config, logger=self.logger)
    self.grader = Grader(client, self.logger)
    self.logger.leave_function()
def reset_memory(self) -> None:
    """
    Empties the memory bank in RAM and on disk.

    Delegates to the memory bank's reset(); the controller's own settings are not touched.
    """
    self.memory_bank.reset()
async def train_on_task(self, task: str, expected_answer: str) -> None:
    """
    Trains on one task: iterates on it to look for a useful insight, and stores any insight
    found as a memo indexed with the task.
    """
    log = self.logger
    log.enter_function()
    log.info("Iterate on the task, possibly discovering a useful new insight.\n")
    # Only the insight matters here; the response returned alongside it is discarded.
    _, insight = await self._iterate_on_task(task, expected_answer)
    if insight is not None:
        log.info("A new insight was created:\n{}".format(insight))
        await self.add_memo(insight, task)
    else:
        log.info("No useful insight was discovered.\n")
    log.leave_function()
async def test_on_task(self, task: str, expected_answer: str, num_trials: int = 1) -> Tuple[str, int, int]:
"""
Assigns a task to the agent, along with any relevant memos retrieved from memory.
"""
self.logger.enter_function()
assert self.task_assignment_callback is not None
response = ""
num_successes = 0
for trial in range(num_trials):
self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
task_plus_insights = task
# Try to retrieve any relevant memories from the DB.
filtered_memos = await self.retrieve_relevant_memos(task)
filtered_insights = [memo.insight for memo in filtered_memos]
if len(filtered_insights) > 0:
self.logger.info("Relevant insights were retrieved from memory.\n")
memory_section = self._format_memory_section(filtered_insights)
if len(memory_section) > 0:
task_plus_insights = task + "\n\n" + memory_section
# Attempt to solve the task.
self.logger.info("Try to solve the task.\n")
response, _ = await self.task_assignment_callback(task_plus_insights)
# Check if the response is correct.
response_is_correct, extracted_answer = await self.grader.is_response_correct(
task, response, expected_answer
)
self.logger.info("Extracted answer: {}".format(extracted_answer))
if response_is_correct:
self.logger.info("Answer is CORRECT.\n")
num_successes += 1
else:
self.logger.info("Answer is INCORRECT.\n")
# Calculate the success rate as a percentage, rounded to the nearest whole number.
self.logger.info("\nSuccess rate: {}%\n".format(round((num_successes / num_trials) * 100)))
self.logger.leave_function()
return response, num_successes, num_trials
async def add_memo(self, insight: str, task: None | str = None, index_on_both: bool = True) -> None:
"""
Adds one insight to the memory bank, using the task (if provided) as context.
"""
self.logger.enter_function()
generalized_task = ""
if task is not None:
self.logger.info("\nGIVEN TASK:")
self.logger.info(task)
if self.generalize_task:
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
else:
generalized_task = task
self.logger.info("\nGIVEN INSIGHT:")
self.logger.info(insight)
# Get a list of topics from the insight and the task (if provided).
if task is None:
text_to_index = insight
self.logger.info("\nTOPICS EXTRACTED FROM INSIGHT:")
else:
if index_on_both:
text_to_index = generalized_task.strip() + "\n(Hint: " + insight + ")"
self.logger.info("\nTOPICS EXTRACTED FROM TASK AND INSIGHT COMBINED:")
else:
text_to_index = task
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
if self.generate_topics:
topics = await self.prompter.find_index_topics(text_to_index)
else:
topics = [text_to_index]
self.logger.info("\n".join(topics))
self.logger.info("")
# Add the insight to the memory bank.
self.memory_bank.add_memo(insight, topics, task)
self.logger.leave_function()
async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
    """
    Stores a task together with its solution, so that both can be retrieved later as one combined insight.

    This is meant for exemplars: a worked task-solution pair that may help with some related task.
    Indexing is based on the whitespace-stripped task only.
    """
    log = self.logger
    log.enter_function()
    for heading, content in (("\nEXAMPLE TASK:", task), ("\nEXAMPLE SOLUTION:", solution)):
        log.info(heading)
        log.info(content)

    # Get a list of topics from the task.
    stripped_task = task.strip()
    if self.generate_topics:
        topics = await self.prompter.find_index_topics(stripped_task)
    else:
        topics = [stripped_task]
    log.info("\nTOPICS EXTRACTED FROM TASK:")
    log.info("\n".join(topics))
    log.info("")

    # Add the task and solution (as a combined insight) to the memory bank.
    self.memory_bank.add_task_with_solution(task=task, solution=solution, topics=topics)
    log.leave_function()
async def retrieve_relevant_memos(self, task: str) -> List[Memo]:
    """
    Collects the memos in memory that seem relevant to the given task.

    The task is optionally generalized and turned into topics, the memory bank is queried with
    those topics, and (if validation is enabled) each candidate is checked by the prompter against
    the original task. At most max_memos_to_retrieve memos are returned; the list is empty when
    the memory bank holds no memos.
    """
    log = self.logger
    log.enter_function()
    validated_memos: List[Memo] = []

    if not self.memory_bank.contains_memos():
        log.info("\nNO SUFFICIENTLY RELEVANT MEMOS WERE FOUND IN MEMORY")
        log.leave_function()
        return validated_memos

    log.info("\nCURRENT TASK:")
    log.info(task)

    # Get a list of topics from the generalized task.
    generalized_task = task
    if self.generalize_task:
        generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
    task_topics = [generalized_task]
    if self.generate_topics:
        task_topics = await self.prompter.find_index_topics(generalized_task)
    log.info("\nTOPICS EXTRACTED FROM TASK:")
    log.info("\n".join(task_topics))
    log.info("")

    # Retrieve candidate memos from the memory bank.
    candidates = self.memory_bank.get_relevant_memos(topics=task_topics)

    # Keep only the candidates that pass validation (when enabled), up to the configured maximum.
    for memo in candidates:
        if len(validated_memos) >= self.max_memos_to_retrieve:
            break
        if self.validate_memos and not await self.prompter.validate_insight(memo.insight, task):
            continue
        validated_memos.append(memo)

    log.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))
    for memo in validated_memos:
        if memo.task is not None:
            log.info("\n TASK: {}".format(memo.task))
        log.info("\n INSIGHT: {}".format(memo.insight))

    log.leave_function()
    return validated_memos
def _format_memory_section(self, memories: List[str]) -> str:
    """
    Renders a list of memories as a markdown section (heading plus one bullet per memory)
    that can be appended to a task description. Returns an empty string for an empty list.
    """
    if not memories:
        return ""
    heading = "## Important insights that may help solve tasks like this\n"
    bullets = "".join("- " + memory + "\n" for memory in memories)
    return heading + bullets
async def _test_for_failure(
    self, task: str, task_plus_insights: str, expected_answer: str
) -> Tuple[bool, str, str]:
    """
    Attempts the task up to max_test_trials times, stopping at the first incorrect response.

    Returns:
        A tuple of (whether a failure was found, the last response, the last work history).
    """
    log = self.logger
    log.enter_function()
    log.info("\nTask description, including any insights: {}".format(task_plus_insights))
    log.info("\nExpected answer: {}\n".format(expected_answer))
    assert self.task_assignment_callback is not None

    failure_found = False
    response, work_history = "", ""
    for trial_number in range(1, self.max_test_trials + 1):
        log.info("\n----- TRIAL {} -----\n".format(trial_number))

        # Attempt to solve the task, then grade the response against the bare task.
        log.info("Try to solve the task.")
        response, work_history = await self.task_assignment_callback(task_plus_insights)
        response_is_correct, extracted_answer = await self.grader.is_response_correct(
            task, response, expected_answer
        )
        log.info("Extracted answer: {}".format(extracted_answer))

        if not response_is_correct:
            log.info("Answer is INCORRECT.\n Stop testing, and return the details of the failure.\n")
            failure_found = True
            break
        log.info("Answer is CORRECT.\n")

    log.leave_function()
    return failure_found, response, work_history
async def _iterate_on_task(self, task: str, expected_answer: str) -> Tuple[str, None | str]:
"""
Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.
"""
self.logger.enter_function()
self.logger.info("\nTask description: {}".format(task))
self.logger.info("\nExpected answer: {}\n".format(expected_answer))
final_response = ""
old_memos = await self.retrieve_relevant_memos(task)
old_insights = [memo.insight for memo in old_memos]
new_insights: List[str] = []
last_insight = None
insight = None
successful_insight = None
# Loop until success (or timeout) while learning from failures.
for trial in range(1, self.max_train_trials + 1):
self.logger.info("\n----- TRAIN TRIAL {} -----\n".format(trial))
task_plus_insights = task
# Add any new insights we've accumulated so far.
if last_insight is not None:
memory_section = self._format_memory_section(old_insights + [last_insight])
else:
memory_section = self._format_memory_section(old_insights)
if len(memory_section) > 0:
task_plus_insights += "\n\n" + memory_section
# Can we find a failure case to learn from?
failure_found, response, work_history = await self._test_for_failure(
task, task_plus_insights, expected_answer
)
if not failure_found:
# No. Time to exit the loop.
self.logger.info("\nResponse is CORRECT.\n Stop looking for insights.\n")
# Was this the first trial?
if trial == 1:
# Yes. We should return the successful response, and no insight.
final_response = response
else:
# No. We learned a successful insight, which should be returned.
successful_insight = insight
break
# Will we try again?
if trial == self.max_train_trials:
# No. We're out of training trials.
self.logger.info("\nNo more trials will be attempted.\n")
break
# Try to learn from this failure.
self.logger.info("\nResponse is INCORRECT. Try to learn from this failure.\n")
insight = await self.prompter.learn_from_failure(
task, memory_section, response, expected_answer, work_history
)
self.logger.info("\nInsight: {}\n".format(insight))
new_insights.append(insight)
last_insight = insight
# Return the answer from the last loop.
self.logger.info("\n{}\n".format(final_response))
self.logger.leave_function()
return final_response, successful_insight
async def _append_any_relevant_memories(self, task: str) -> str:
    """
    Returns the task description, extended with a memory section when relevant insights are found.
    The task is returned unchanged when nothing relevant is retrieved.
    """
    self.logger.enter_function()
    memos = await self.retrieve_relevant_memos(task)
    insights = [memo.insight for memo in memos]
    augmented_task = task
    if insights:
        self.logger.info("Relevant insights were retrieved from memory.\n")
        memory_section = self._format_memory_section(insights)
        if memory_section:
            augmented_task = task + "\n\n" + memory_section
    self.logger.leave_function()
    return augmented_task
async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
    """
    Hands a task to the agent via the task_assignment_callback and returns the agent's response.

    When use_memory is True, any relevant memories are appended to the task first.
    Only awaited execution is supported: should_await must be True.
    """
    self.logger.enter_function()
    assert self.task_assignment_callback is not None

    prompt = await self._append_any_relevant_memories(task) if use_memory else task

    # Attempt to solve the task.
    self.logger.info("Try to solve the task.\n")
    assert should_await
    # The callback also returns a work history, which is not needed here.
    response, _ = await self.task_assignment_callback(prompt)

    self.logger.leave_function()
    return response
async def consider_memo_storage(self, text: str) -> str | None:
"""
Tries to extract any advice from the given text and add it to memory.
"""
self.logger.enter_function()
advice = await self.prompter.extract_advice(text)
self.logger.info("Advice: {}".format(advice))
if advice is not None:
await self.add_memo(insight=advice)
self.logger.leave_function()
return advice
async def handle_user_message(self, text: str, should_await: bool = True) -> str:
    """
    Processes a user message in two steps: first any advice in it is extracted and stored in memory,
    then the message is assigned as a task through assign_task().

    Memory is consulted for the task only when the message itself contained no advice.
    """
    self.logger.enter_function()

    # Check for advice.
    advice = await self.consider_memo_storage(text)

    # Use memory only if no advice was just provided.
    no_new_advice = advice is None
    reply = await self.assign_task(text, use_memory=no_new_advice, should_await=should_await)

    self.logger.leave_function()
    return reply

View File

@@ -0,0 +1,15 @@
from .apprentice import Apprentice, ApprenticeConfig
from .chat_completion_client_recorder import ChatCompletionClientRecorder
from .grader import Grader
from .page_logger import PageLogger, PageLoggerConfig
from .teachability import Teachability
# The public names of this package, i.e. what `from <package> import *` exposes.
# Each entry corresponds to a name imported at the top of this module.
__all__ = [
    "Apprentice",
    "ChatCompletionClientRecorder",
    "Grader",
    "PageLogger",
    "Teachability",
    "ApprenticeConfig",
    "PageLoggerConfig",
]

View File

@@ -0,0 +1,96 @@
import hashlib
import os
from typing import List, Tuple, Union
from agentdhal_core import FunctionCall, Image
from agentdhal_core.models import FunctionExecutionResult
# Convenience types
# User content: plain text, or a list mixing text and images.
UserContent = Union[str, List[Union[str, Image]]]
# Assistant content: plain text, or a list of function calls.
AssistantContent = Union[str, List[FunctionCall]]
# Function execution content: a list of function execution results.
FunctionExecutionContent = List[FunctionExecutionResult]
# System content: plain text only.
SystemContent = str
# Any of the content types above.
MessageContent = UserContent | AssistantContent | SystemContent | FunctionExecutionContent
def message_content_to_str(message_content: MessageContent | None) -> str:
    """
    Renders message content as a single string.

    None becomes an empty string and a string is returned as is. For a list, each item is rendered
    on its own line: strings unchanged, images as the placeholder "<Image>", and anything else via
    str() with trailing whitespace removed.

    Raises:
        AssertionError: If the content is neither None, a string, nor a list.
    """

    def render(item: object) -> str:
        if isinstance(item, str):
            return item
        if isinstance(item, Image):
            return "<Image>"
        return str(item).rstrip()

    if message_content is None:
        return ""
    if isinstance(message_content, str):
        return message_content
    if isinstance(message_content, list):
        return "\n".join([render(item) for item in message_content])
    raise AssertionError("Unexpected response type.")
def text_from_user_content(user_content: UserContent) -> str:
    """
    Returns only the textual part of the user content.

    A string is returned as is. For a list, the string items are stripped of trailing whitespace
    and joined with blank lines; non-string items (such as images) are ignored.

    Raises:
        AssertionError: If the content is neither a string nor a list.
    """
    if isinstance(user_content, str):
        return user_content
    if isinstance(user_content, list):
        return "\n\n".join([item.rstrip() for item in user_content if isinstance(item, str)])
    raise AssertionError("Unexpected response type.")
def single_image_from_user_content(user_content: UserContent) -> Union[Image, None]:
    """
    Returns the image contained in the user content, or None if there is none.

    String content never contains an image. List content may contain at most one image;
    a second image trips an assertion.

    Raises:
        AssertionError: If more than one image is present, or if the content is neither a string nor a list.
    """
    if isinstance(user_content, str):
        return None
    if not isinstance(user_content, list):
        raise AssertionError("Unexpected response type.")

    found = None
    for item in user_content:
        if not isinstance(item, Image):
            continue
        assert found is None, "Only one image is currently allowed in the user content."
        found = item
    return found
def hash_directory(directory: str, hash_algo: str = "sha256") -> Tuple[str, int, int]:
    """
    Computes a hash representing the state of a directory, including its structure and file contents.

    The digest covers, level by level, the names of the sub-directories, the names of the files,
    and the bytes of every readable file. Names are hashed without their parent path.

    Args:
        directory: The root directory to walk.
        hash_algo: Any algorithm name accepted by hashlib.new().

    Returns:
        A tuple of (hex digest, total number of files, total number of sub-directories).
    """
    digest = hashlib.new(hash_algo)
    file_count = 0
    subdir_count = 0

    # Sorting the walk, and the names within each level, keeps the hash independent of listing order.
    for current_dir, subdir_names, file_names in sorted(os.walk(directory)):
        file_count += len(file_names)
        subdir_count += len(subdir_names)

        for subdir_name in sorted(subdir_names):
            digest.update(subdir_name.encode())

        for file_name in sorted(file_names):
            digest.update(file_name.encode())
            try:
                with open(os.path.join(current_dir, file_name), "rb") as stream:
                    # Feed the file to the hash in 4 KiB chunks rather than loading it whole.
                    for chunk in iter(lambda: stream.read(4096), b""):
                        digest.update(chunk)
            except Exception:
                # Best effort: a file that cannot be read contributes only its name.
                pass

    return digest.hexdigest(), file_count, subdir_count

View File

@@ -0,0 +1,257 @@
import random
import time
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
from agentdhal_agentchat.agents import AssistantAgent
from agentdhal_agentchat.base import TaskResult
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
from agentdhal_core.models import (
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from .page_logger import PageLogger
if TYPE_CHECKING:
from ..memory_controller import MemoryControllerConfig
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class ApprenticeConfig(TypedDict, total=False):
    """
    Optional overrides for Apprentice settings. Every key may be omitted (total=False).
    """

    name_of_agent_or_team: str  # The name of the target agent or team for assigning tasks to.
    disable_prefix_caching: bool  # True to prepend random ints to the first message, defeating prefix caching.
    MemoryController: "MemoryControllerConfig"  # A config dict passed through to MemoryController.
class Apprentice:
    """
    A minimal wrapper combining task-centric memory with an agent or team.
    Applications may use the Apprentice class, or they may directly instantiate
    and call the Memory Controller using this class as an example.

    Args:
        client: The client to call the model.
        config: An optional dict that can be used to override the following values:

            - name_of_agent_or_team: The name of the target agent or team for assigning tasks to.
            - disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message.
            - MemoryController: A config dict passed to MemoryController.

        logger: An optional logger. If None, a default logger will be created.
    """
def __init__(
    self,
    client: ChatCompletionClient,
    config: ApprenticeConfig | None = None,
    logger: PageLogger | None = None,
) -> None:
    """
    Applies the default settings plus any config overrides, then creates the MemoryController
    with its memory reset and with this object's assign_task_to_agent_or_team as the task callback.
    """
    # Fall back to a DEBUG-level PageLogger when the caller supplies none.
    self.logger = logger if logger is not None else PageLogger({"level": "DEBUG"})

    # Each setting takes its value from the config when present, else the default given here.
    settings: ApprenticeConfig = config if config is not None else {}
    self.name_of_agent_or_team = settings.get("name_of_agent_or_team", "AssistantAgent")
    self.disable_prefix_caching = settings.get("disable_prefix_caching", False)
    memory_controller_config = settings.get("MemoryController", None)

    self.client = client
    if self.disable_prefix_caching:
        # A dedicated generator, seeded from the current time in milliseconds.
        self.rand = random.Random(int(time.time() * 1000))

    # Create the MemoryController, which creates the MemoryBank.
    from ..memory_controller import MemoryController

    self.memory_controller = MemoryController(
        reset=True,
        client=self.client,
        task_assignment_callback=self.assign_task_to_agent_or_team,
        config=memory_controller_config,
        logger=self.logger,
    )
def reset_memory(self) -> None:
    """
    Resets the memory bank.

    Delegates to the memory controller's reset_memory().
    """
    self.memory_controller.reset_memory()
async def handle_user_message(self, text: str, should_await: bool = True) -> str:
    """
    Forwards a user message to the memory controller, which extracts any advice from it
    and assigns it as a task to the agent. Returns the resulting response.
    """
    self.logger.enter_function()
    controller = self.memory_controller
    reply = await controller.handle_user_message(text, should_await)
    self.logger.leave_function()
    return reply
async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
    """
    Stores a task-solution pair in memory, to be retrieved together later as a combined insight.
    Useful when the insight is a demonstration of how to solve a given type of task.
    """
    self.logger.enter_function()
    # The memory controller does all the work; this method only brackets the call with logging.
    controller = self.memory_controller
    await controller.add_task_solution_pair_to_memory(task, solution)
    self.logger.leave_function()
async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
    """
    Forwards a task to the memory controller, which assigns it to the agent along with
    any relevant insights/memories. Returns the agent's response.
    """
    self.logger.enter_function()
    controller = self.memory_controller
    reply = await controller.assign_task(task, use_memory, should_await)
    self.logger.leave_function()
    return reply
async def train_on_task(self, task: str, expected_answer: str) -> None:
    """
    Forwards a training task to the memory controller, which repeatedly assigns it to the agent
    and tries to learn from failures by creating useful insights as memories.
    """
    self.logger.enter_function()
    # The memory controller does all the work; this method only brackets the call with logging.
    controller = self.memory_controller
    await controller.train_on_task(task, expected_answer)
    self.logger.leave_function()
async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]:
    """
    Routes the task to the handler for the configured target (self.name_of_agent_or_team).

    Returns:
        A tuple of (response, work history) from the chosen handler.

    Raises:
        AssertionError: If the configured name is neither "MagenticOneGroupChat" nor "AssistantAgent".
    """
    self.logger.enter_function()

    target = self.name_of_agent_or_team
    if target == "MagenticOneGroupChat":
        handler = self._assign_task_to_magentic_one
    elif target == "AssistantAgent":
        handler = self._assign_task_to_assistant_agent
    else:
        raise AssertionError("Invalid base agent")
    response, work_history = await handler(task)

    self.logger.leave_function()
    return response, work_history
async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[Any, Any]:
    """
    Runs the task on a freshly created AssistantAgent that uses a generic 6-step system prompt.

    Returns:
        A tuple of (response text, work history). The text of the last message in the task result
        is used for both.
    """
    self.logger.enter_function()
    self.logger.info(task)

    prompt = """You are a helpful and thoughtful assistant.
In responding to every user message, you follow the same multi-step process given here:
1. Explain your understanding of the user message in detail, covering all the important points.
2. List as many possible responses as you can think of.
3. Carefully list and weigh the pros and cons (if any) of each possible response.
4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist.
5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory.
6. Finish by providing your final response in the particular format requested by the user."""
    if self.disable_prefix_caching:
        # Prepend a random int to disable prefix caching.
        prompt = "({})\n\n".format(self.rand.randint(0, 1000000)) + prompt

    # The message list below is only handed to the logger; the agent receives the prompt directly.
    leading_message: LLMMessage
    if self.client.model_info["family"] == "o1":
        # No system message allowed, so pass it as the first user message.
        leading_message = UserMessage(content=prompt, source="User")
    else:
        # System message allowed.
        leading_message = SystemMessage(content=prompt)
    input_messages: List[LLMMessage] = [leading_message, UserMessage(content=task, source="User")]

    assistant_agent = AssistantAgent("assistant_agent", self.client, system_message=prompt)

    # Get the agent's response to the task: the text of the last message in the result.
    task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
    all_messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
    response_str = all_messages[-1].to_text()

    # Log the model call
    self.logger.log_model_task(
        summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result
    )
    self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))

    self.logger.leave_function()
    # Use the response as the work history as well.
    return response_str, response_str
async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]:
    """
    Builds a MagenticOneGroupChat team (a general assistant plus a web surfer) and runs the task on it.

    Returns:
        A tuple of (response, work history). Both are the team's whole chat history,
        one message per line.
    """
    self.logger.enter_function()
    self.logger.info(task)

    general_agent = AssistantAgent(
        "general_agent",
        self.client,
        description="A general GPT-4o AI assistant capable of performing a variety of tasks.",
    )

    from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer

    web_surfer = MultimodalWebSurfer(
        name="web_surfer",
        model_client=self.client,
        downloads_folder="logs",
        debug_dir="logs",
        to_save_screenshots=True,
    )

    from agentdhal_agentchat.teams import MagenticOneGroupChat

    team = MagenticOneGroupChat(
        [general_agent, web_surfer],
        model_client=self.client,
        max_turns=20,
    )

    # Get the team's response to the task.
    task_result: TaskResult = await team.run(task=task)
    assert isinstance(task_result, TaskResult)

    # MagenticOne's response is the chat history, which is used here as the work history too.
    chat_history = "\n".join([message.to_text() for message in task_result.messages])
    self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(chat_history))

    self.logger.leave_function()
    return chat_history, chat_history

View File

@@ -0,0 +1,227 @@
import json
import os
import warnings
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union
from agentdhal_core import CancellationToken
from agentdhal_core.models import (
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelCapabilities, # type: ignore
ModelInfo,
RequestUsage,
)
from agentdhal_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from .page_logger import PageLogger
class RecordDict(TypedDict):
    """
    One recorded client call, as stored in the session's list of records.
    """

    mode: Literal["create", "create_stream"]  # Which client method produced this record.
    messages: List[Mapping[str, Any]]  # The serialized input messages of the call.
    response: Dict[str, Any]  # The serialized response of a create call.
    stream: List[Mapping[str, Any]]  # The streamed outputs of a create_stream call; empty for create calls.
class ChatCompletionClientRecorder(ChatCompletionClient):
    """
    A chat completion client that supports fast, large-scale tests of code calling LLM clients.

    Two modes are supported:

    1. "record": delegates to the underlying client while also recording the input messages and responses,
       which are saved to disk when finalize() is called.
    2. "replay": loads previously recorded messages and responses from disk, then on each call
       checks that its messages match the recorded messages, and returns the recorded response.

    The recorded data is stored as a JSON list of records. Each record is a dictionary with a "mode"
    field (either "create" or "create_stream"), a serialized list of messages, and either a "response" (for
    create calls) or a "stream" (a list of streamed outputs for create_stream calls).

    NOTE: as implemented here, only create() calls are recorded and replayed. create_stream() is
    forwarded to the base client unchanged in both modes, so no "create_stream" records are produced.

    ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:

    - ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything
      or checking the messages sent to the client.
    - ChatCompletionCache caches responses and replays them for messages that have been seen before,
      regardless of order, and calls the base client for any uncached messages.

    Args:
        client: The base client that is wrapped.
        mode: Either "record" or "replay".
        session_file_path: Where the session is saved to (record) or loaded from (replay). "~" is expanded.
        logger: An optional logger. If None, a disabled PageLogger is used.
    """

    def __init__(
        self,
        client: ChatCompletionClient,
        mode: Literal["record", "replay"],
        session_file_path: str,
        logger: PageLogger | None = None,
    ) -> None:
        """
        Wraps the base client, and in replay mode loads the recorded session from disk.

        Raises:
            ValueError: In replay mode, if the session file cannot be read or parsed.
        """
        if logger is None:
            self.logger = PageLogger()  # Disabled by default.
        else:
            self.logger = logger
        self.logger.enter_function()
        self.logger.info("Wrapping the base client in ChatCompletionClientRecorder.")

        self.base_client = client
        self.mode = mode
        self.session_file_path = os.path.expanduser(session_file_path)
        self.records: List[RecordDict] = []
        self._record_index = 0  # Position of the next record to replay.
        self._num_checked_records = 0  # How many records have been replayed and verified.

        if self.mode == "record":
            # Prepare to record the messages and responses.
            self.logger.info("Recording mode enabled.\nRecording session to: " + self.session_file_path)
        elif self.mode == "replay":
            # Load the previously recorded messages and responses from disk.
            self.logger.info("Replay mode enabled.\nRetrieving session from: " + self.session_file_path)
            try:
                with open(self.session_file_path, "r", encoding="utf-8") as f:
                    self.records = json.load(f)
            except Exception as e:
                error_str = f"\nFailed to load recorded session: '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e

        self.logger.leave_function()

    # The mutable defaults below mirror the signatures used throughout this class and are never mutated here.
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> CreateResult:
        """
        In record mode, calls the base client and records the messages and response.
        In replay mode, checks the messages against the next record and returns the recorded response.

        The replayed CreateResult is rebuilt from the recorded "content", "finish_reason" and "usage"
        fields only, and is marked as cached.

        Raises:
            ValueError: In replay mode, if the records are exhausted, the next record is not a "create"
                record, or the messages differ from the recorded ones. Also raised for an unknown mode.
        """
        current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]

        if self.mode == "record":
            response = await self.base_client.create(
                messages,
                tools=tools,
                json_output=json_output,
                tool_choice=tool_choice,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )
            rec: RecordDict = {
                "mode": "create",
                "messages": current_messages,
                "response": response.model_dump(),
                "stream": [],
            }
            self.records.append(rec)
            return response
        elif self.mode == "replay":
            if self._record_index >= len(self.records):
                error_str = "\nNo more recorded turns to check."
                self.logger.error(error_str)
                raise ValueError(error_str)
            rec = self.records[self._record_index]
            if rec.get("mode") != "create":
                error_str = f"\nRecorded call type mismatch at index {self._record_index}: expected 'create', got '{rec.get('mode')}'."
                self.logger.error(error_str)
                raise ValueError(error_str)
            recorded_messages = rec.get("messages")
            if recorded_messages != current_messages:
                error_str = (
                    "\nCurrent message list doesn't match the recorded message list. See the pagelogs for details."
                )
                # Log both lists so the difference can be inspected in the page logs.
                assert recorded_messages is not None
                self.logger.log_dict_list(recorded_messages, "recorded message list")
                assert current_messages is not None
                self.logger.log_dict_list(current_messages, "current message list")
                self.logger.error(error_str)
                raise ValueError(error_str)
            self._record_index += 1
            self._num_checked_records += 1

            data = rec.get("response")
            # Populate a CreateResult from the data.
            assert data is not None
            result = CreateResult(
                content=data.get("content", ""),
                finish_reason=data.get("finish_reason", "stop"),
                usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
                cached=True,
            )
            return result
        else:
            error_str = f"\nUnknown mode: {self.mode}"
            self.logger.error(error_str)
            raise ValueError(error_str)

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Forwards the call to the base client and returns its stream.
        Nothing is recorded or replayed here, regardless of the mode.
        """
        return self.base_client.create_stream(
            messages,
            tools=tools,
            tool_choice=tool_choice,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )

    async def close(self) -> None:
        """Closes the base client."""
        await self.base_client.close()

    def actual_usage(self) -> RequestUsage:
        # Calls base_client.actual_usage() and returns the result.
        return self.base_client.actual_usage()

    def total_usage(self) -> RequestUsage:
        # Calls base_client.total_usage() and returns the result.
        return self.base_client.total_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        # Calls base_client.count_tokens() and returns the result.
        return self.base_client.count_tokens(messages, tools=tools)

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        # Calls base_client.remaining_tokens() and returns the result.
        return self.base_client.remaining_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        # Calls base_client.capabilities and returns the result.
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.base_client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        # Calls base_client.model_info and returns the result.
        return self.base_client.model_info

    def finalize(self) -> None:
        """
        In record mode, saves the accumulated records to disk.
        In replay mode, makes sure all the records were checked.

        Raises:
            ValueError: In record mode, if the records cannot be written.
                In replay mode, if fewer records were checked than were recorded.
        """
        self.logger.enter_function()
        if self.mode == "record":
            try:
                # Create the directory if it doesn't exist. A bare file name has no directory
                # component, and os.makedirs("") would raise, so skip the call in that case.
                session_dir = os.path.dirname(self.session_file_path)
                if session_dir:
                    os.makedirs(session_dir, exist_ok=True)
                # Write the records to disk. json.dump escapes non-ASCII characters by default,
                # so the explicit encoding does not change the bytes that are written.
                with open(self.session_file_path, "w", encoding="utf-8") as f:
                    json.dump(self.records, f, indent=2)
                    self.logger.info("\nRecorded session was saved to: " + self.session_file_path)
            except Exception as e:
                error_str = f"Failed to write records to '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e
        elif self.mode == "replay":
            if self._num_checked_records < len(self.records):
                error_str = f"\nEarly termination. Only {self._num_checked_records} of the {len(self.records)} recorded turns were checked."
                self.logger.error(error_str)
                raise ValueError(error_str)
            self.logger.info("\nRecorded session was fully replayed and checked.")
        self.logger.leave_function()

View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Tuple, Union
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
SystemMessage,
UserMessage,
)
from ._functions import UserContent
from .page_logger import PageLogger
if TYPE_CHECKING:
from .apprentice import Apprentice
class Grader:
    """
    Runs basic tests, and determines task success without limitation to string matches.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.client = client

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def test_apprentice(
        self,
        apprentice: Apprentice,
        task_description: str,
        expected_answer: str,
        num_trials: int,
        use_memory: bool,
        client: ChatCompletionClient,
    ) -> Tuple[int, int]:
        """
        Assigns the task to the apprentice num_trials times, grading each response.

        Args:
            apprentice: The apprentice to be tested.
            task_description: The task given to the apprentice on every trial.
            expected_answer: The answer each response is graded against.
            num_trials: The number of times to assign the task.
            use_memory: Passed through to apprentice.assign_task().
            client: Not used by this method; grading uses the client passed to the constructor.

        Returns:
            A tuple of (number of successful trials, number of trials).
        """
        self.logger.enter_function()

        self.logger.info("Testing the apprentice on the given task.\n")

        num_successes = 0

        for trial in range(num_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
            self.logger.info("Try to solve the task.\n")
            response = await apprentice.assign_task(task_description, use_memory=use_memory)
            response_is_correct, extracted_answer = await self.is_response_correct(
                task_description, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # With no trials there is nothing to divide by, so report 0% instead of raising.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info("\nSuccess rate: {}%\n".format(success_rate))
        self.logger.leave_function()
        return num_successes, num_trials

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: A short description of the call, used for logging.
            user_content: The content of the user message.
            system_message_content: The system prompt. Defaults to "You are a helpful assistant."
            keep_these_messages: Whether to append the user message and the response to the chat history.

        Returns:
            The model's response as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = "You are a helpful assistant."
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)

        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Call the model.
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")
        assert isinstance(response_message, AssistantMessage)

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def is_response_correct(
        self, task_description: str, response_to_be_graded: str, correct_answer: str
    ) -> Tuple[bool, str]:
        """
        Determines whether the response is equivalent to the task's correct answer.

        Makes two model calls, each starting from an empty chat history: the first extracts the
        answer from the response, and the second grades the extracted answer against the correct one.

        Returns:
            A tuple of (whether the answer was judged correct, the extracted answer).
        """
        self.logger.enter_function()

        sys_message = """You are a helpful and thoughtful assistant."""

        # NOTE: The prompt text below is sent to the model verbatim, so it is deliberately left
        # untouched; editing it would change the messages seen by the client.

        # Ask the model to extract the answer from the response.
        user_message: List[Union[str, Image]] = []
        user_message.append("""Your job is to extract a possible answer to the following question from the given text.
- First review the following task.
- Then review the text that follows, which may an answer, plus reasoning that led to the answer.
- Do not attempt to actually solve the task yourself.
- Don't try to judge whether the reasoning steps were correct.
- Simply respond by summarizing the answer described in the text, omitting any other parts of the text.
- If no answer is present can be extracted from the text, simply reply "None".""")
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Text that may contain an answer")
        user_message.append(response_to_be_graded)
        user_message_arg: UserContent = user_message
        self._clear_history()
        extracted_answer = await self.call_model(
            summary="Ask the model to extract the answer",
            system_message_content=sys_message,
            user_content=user_message_arg,
        )
        self.logger.info("Extracted answer: " + extracted_answer)

        # Ask the model to check the answer for correctness.
        user_message = [
            """Your job is to decide whether a given answer to a task is correct or not.
- You will be given the task description and the correct, gold-standard answer, along with the answer to be graded.
- In general, an answer is correct if it is equivalent to the correct answer.
- Specifically, the given answer must contain the important information from the correct answer, and must not in any way contradict the correct answer.
- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, or extra commentary.
- An answer should be considered correct if it omits information that is clearly inferred.
- For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct.
- Respond with a single character: '1' if the answer to be graded is correct", '0' if not."""
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Correct answer")
        user_message.append(correct_answer)
        user_message.append("\n# Answer to be graded")
        user_message.append(extracted_answer)
        self._clear_history()
        decision = await self.call_model(
            summary="Ask the model to check the answer for correctness",
            system_message_content=sys_message,
            user_content=user_message,
        )
        self.logger.info("Decision: " + decision)

        self.logger.leave_function()
        # Tolerate surrounding whitespace (such as a trailing newline) around the one-character verdict.
        return decision.strip() == "1", extracted_answer

View File

@@ -0,0 +1,546 @@
import inspect
import json
import os
import shutil
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
from agentdhal_agentchat.base import TaskResult
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
from agentdhal_core import Image
from agentdhal_core.models import (
AssistantMessage,
CreateResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from ._functions import MessageContent, hash_directory
def _html_opening(file_title: str, finished: bool = False) -> str:
"""
Returns the opening text of a simple HTML file.
"""
refresh_tag = '<meta http-equiv="refresh" content="2">' if not finished else ""
st = f"""
<!DOCTYPE html>
<html>
<head>
{refresh_tag}
<title>{file_title}</title>
<style>
body {{font-size: 20px}}
body {{white-space: pre-wrap}}
</style>
</head>
<body>"""
return st
def _html_closing() -> str:
"""
Return the closing text of a simple HTML file.
"""
return """</body></html>"""
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class PageLoggerConfig(TypedDict, total=False):
    """Optional settings for PageLogger. Both keys may be omitted (total=False)."""

    # Name of the logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
    level: str
    # Directory where the log files are written. PageLogger expands a leading "~".
    path: str
class PageLogger:
    """
    Logs text and images to a set of HTML pages, one per function/method, linked to each other in a call tree.

    Args:
        config: An optional dict that can be used to override the following values:

            - level: The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
            - path: The path to the directory where the log files will be written.
    """

    def __init__(self, config: PageLoggerConfig | None = None) -> None:
        self.levels = {
            "DEBUG": 10,
            "INFO": 20,
            "WARNING": 30,
            "ERROR": 40,
            "CRITICAL": 50,
            "NONE": 100,
        }

        # Set the cheap state first, so that every method (including __del__ -> finalize) finds the
        # attributes it reads even when logging is disabled or the setup below fails partway.
        self.level = self.levels["NONE"]
        self.finalized = False
        self.page_stack = PageStack()
        self.pages: List[Page] = []
        self.last_page_id = 0
        self.name = "0 Call Tree"

        # Apply default settings and any config overrides.
        level_str = "NONE"  # Default to no logging at all.
        self.log_dir = "./pagelogs/default"
        if config is not None:
            level_str = config.get("level", level_str)
            self.log_dir = config.get("path", self.log_dir)
        self.level = self.levels[level_str]  # Raises KeyError for an unrecognized level name.
        self.log_dir = os.path.expanduser(self.log_dir)

        # If the logging level is set to NONE or higher, don't log anything.
        if self.level >= self.levels["NONE"]:
            return

        self._create_run_dir()
        self.flush()

    def __del__(self) -> None:
        self.finalize()

    def finalize(self) -> None:
        """
        Flushes the call tree as finished, then writes a hash of the log directory
        (plus file and subdirectory counts) to hash.txt for change detection.
        """
        if self.level >= self.levels["NONE"]:
            return

        # Don't finalize the log if it has already been finalized.
        if self.finalized:
            return

        # Do nothing if the app is being forced to exit early.
        if self.page_stack.size() > 0:
            return

        self.flush(finished=True)

        # Write the hash and other details to a file.
        hash_str, num_files, num_subdirs = hash_directory(self.log_dir)
        hash_path = os.path.join(self.log_dir, "hash.txt")
        with open(hash_path, "w") as f:
            f.write(hash_str)
            f.write("\n")
            f.write("{} files\n".format(num_files))
            f.write("{} subdirectories\n".format(num_subdirs))

        self.finalized = True

    @staticmethod
    def _decorate_text(text: str, color: str, weight: str = "bold", demarcate: bool = False) -> str:
        """
        Returns a string of text with HTML styling for weight and color.
        If demarcate is True, the text is first wrapped in <<<<< >>>>> markers.
        """
        if demarcate:
            text = f"<<<<< {text} >>>>>"
        return f'<span style="color: {color}; font-weight: {weight};">{text}</span>'

    @staticmethod
    def _link_to_image(image_path: str, description: str) -> str:
        """
        Returns an HTML string defining a thumbnail link to an image.
        """
        # To avoid a bug in HTML rendering that displays underscores to the left of thumbnails,
        # define the following string on a single line.
        link = f"""<a href="{image_path}"><img src="{image_path}" alt="{description}" style="width: 300px; height: auto;"></a>"""
        return link

    def _get_next_page_id(self) -> int:
        """Returns the next page id and increments the counter."""
        self.last_page_id += 1
        return self.last_page_id

    def _create_run_dir(self) -> None:
        """Creates a fresh log directory, deleting any existing directory at the same path."""
        if os.path.exists(self.log_dir):
            shutil.rmtree(self.log_dir)
        os.makedirs(self.log_dir)

    def _add_page(self, summary: str, show_in_call_tree: bool = True, finished: bool = True) -> "Page":
        """
        Adds a new page to the log, and links to it from the page on top of the stack, if any.
        """
        page = Page(
            page_logger=self,
            index=self._get_next_page_id(),
            summary=summary,
            indent_level=len(self.page_stack.stack),
            show_in_call_tree=show_in_call_tree,
            finished=finished,
        )
        self.pages.append(page)
        self.flush()

        if len(self.page_stack.stack) > 0:
            # Insert a link to the new page into the calling page.
            self.info("\n" + page.full_link)

        return page

    def _log_text(self, text: str) -> None:
        """
        Adds text to the current page. Does nothing if no page is on the stack.
        """
        page = self.page_stack.top()
        if page is not None:
            page.add_lines(text, flush=True)

    def debug(self, line: str) -> None:
        """
        Adds DEBUG text to the current page if debugging level <= DEBUG.
        """
        if self.level <= self.levels["DEBUG"]:
            self._log_text(line)

    def info(self, line: str) -> None:
        """
        Adds INFO text to the current page if debugging level <= INFO.
        """
        if self.level <= self.levels["INFO"]:
            self._log_text(line)

    def warning(self, line: str) -> None:
        """
        Adds WARNING text to the current page if debugging level <= WARNING.
        """
        if self.level <= self.levels["WARNING"]:
            self._log_text(line)

    def error(self, line: str) -> None:
        """
        Adds ERROR text to the current page if debugging level <= ERROR.
        """
        if self.level <= self.levels["ERROR"]:
            self._log_text(line)

    def critical(self, line: str) -> None:
        """
        Adds CRITICAL text to the current page if debugging level <= CRITICAL.
        """
        if self.level <= self.levels["CRITICAL"]:
            self._log_text(line)

    def _message_source(self, message: LLMMessage) -> str:
        """
        Returns a decorated string indicating the source of a message.
        """
        source = "UNKNOWN"
        color = "black"
        if isinstance(message, SystemMessage):
            source = "SYSTEM"
            color = "purple"
        elif isinstance(message, UserMessage):
            source = "USER"
            color = "blue"
        elif isinstance(message, AssistantMessage):
            source = "ASSISTANT"
            color = "green"
        elif isinstance(message, FunctionExecutionResultMessage):
            source = "FUNCTION"
            color = "red"
        return self._decorate_text(source, color, demarcate=True)

    def _format_message_content(self, message_content: MessageContent) -> str:
        """
        Formats the message content for logging.
        Any images in the content are saved to the log directory and replaced by thumbnail links.
        """
        # Start by converting the message content to a list of strings.
        content_list: List[str] = []
        content = message_content
        if isinstance(content, str):
            content_list.append(content)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, str):
                    content_list.append(item.rstrip())
                elif isinstance(item, Image):
                    # Save the image to disk. Image files draw their numbers from the page id counter.
                    image_filename = str(self._get_next_page_id()) + " image.jpg"
                    image_path = os.path.join(self.log_dir, image_filename)
                    item.image.save(image_path)
                    # Add a link to the image.
                    content_list.append(self._link_to_image(image_filename, "message_image"))
                elif isinstance(item, dict):
                    # Add a dictionary to the log.
                    json_str = json.dumps(item, indent=4)
                    content_list.append(json_str)
                else:
                    content_list.append(str(item).rstrip())
        else:
            content_list.append("<UNKNOWN MESSAGE CONTENT>")

        # Convert the list of strings to a single string containing newline separators.
        return "".join(f"\n{item}\n" for item in content_list)

    def log_message_content(self, message_content: MessageContent, summary: str) -> None:
        """
        Adds a page containing the message's content, including any images.
        """
        if self.level > self.levels["INFO"]:
            return None
        page = self._add_page(summary=summary, show_in_call_tree=False)
        self.page_stack.write_stack_to_page(page)
        page.add_lines(self._format_message_content(message_content=message_content))
        page.flush()

    def log_dict_list(self, content: List[Mapping[str, Any]], summary: str) -> None:
        """
        Adds a page containing a list of dicts, each rendered as indented JSON.
        """
        if self.level > self.levels["INFO"]:
            return None
        page = self._add_page(summary=summary, show_in_call_tree=False)
        self.page_stack.write_stack_to_page(page)

        for item in content:
            json_str = json.dumps(item, indent=4)
            page.add_lines(json_str)

        page.flush()

    def _log_model_messages(
        self, summary: str, input_messages: List[LLMMessage], response_str: str, usage: RequestUsage | None
    ) -> Optional["Page"]:
        """
        Adds a page containing the messages to a model (including any input images) and its response.
        """
        page = self._add_page(summary=summary, show_in_call_tree=False)
        self.page_stack.write_stack_to_page(page)

        if usage is not None:
            page.add_lines("{} prompt tokens".format(usage.prompt_tokens))
            page.add_lines("{} completion tokens".format(usage.completion_tokens))
        for m in input_messages:
            page.add_lines("\n" + self._message_source(m))
            page.add_lines(self._format_message_content(message_content=m.content))
        page.add_lines("\n" + self._decorate_text("ASSISTANT RESPONSE", "green", demarcate=True))
        page.add_lines("\n" + response_str + "\n")
        page.flush()
        return page

    def log_model_call(
        self, summary: str, input_messages: List[LLMMessage], response: CreateResult
    ) -> Optional["Page"]:
        """
        Logs messages sent to a model and the CreateResult response to a new page.
        A response whose content is not a string is logged as "??".
        """
        if self.level > self.levels["INFO"]:
            return None

        response_str = response.content
        if not isinstance(response_str, str):
            response_str = "??"

        page = self._log_model_messages(summary, input_messages, response_str, response.usage)
        return page

    def log_model_task(
        self, summary: str, input_messages: List[LLMMessage], task_result: TaskResult
    ) -> Optional["Page"]:
        """
        Logs messages sent to a model and the TaskResult response to a new page.
        Only the last message of the task result is logged as the response.
        """
        if self.level > self.levels["INFO"]:
            return None

        messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
        message = messages[-1]
        response_str = message.to_text()
        if not isinstance(response_str, str):
            response_str = "??"

        if hasattr(message, "models_usage"):
            usage: RequestUsage | None = message.models_usage
        else:
            usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

        page = self._log_model_messages(summary, input_messages, response_str, usage)
        return page

    def log_link_to_local_file(self, file_path: str) -> str:
        """
        Returns a link to a local file in the log.
        The link targets the file's base name only; nothing is copied or written by this method.
        """
        file_name = os.path.basename(file_path)
        link = f'<a href="{file_name}">{file_name}</a>'
        return link

    def add_link_to_image(self, description: str, source_image_path: str) -> None:
        """
        Inserts a thumbnail link to an image to the page.
        The image file is copied into the log directory. Does nothing unless the level is INFO or lower.
        """
        # Like the other page-writing methods, do nothing when page logging is off. Without this
        # check, a disabled logger (the default) has no log directory to copy the image into.
        if self.level > self.levels["INFO"]:
            return None

        # Remove every character from the string 'description' that is not alphanumeric or a space.
        description = "".join(e for e in description if e.isalnum() or e.isspace())
        target_image_filename = str(self._get_next_page_id()) + " - " + description
        # Copy the image to the log directory.
        local_image_path = os.path.join(self.log_dir, target_image_filename)
        shutil.copyfile(source_image_path, local_image_path)
        self._log_text("\n" + description)
        self._log_text(self._link_to_image(target_image_filename, description))

    def flush(self, finished: bool = False) -> None:
        """
        Writes the current state of the log to disk.
        Rewrites the call-tree page, listing every page that is marked to be shown in it.
        """
        if self.level > self.levels["INFO"]:
            return
        # Create a call tree of the log.
        call_tree_path = os.path.join(self.log_dir, self.name + ".html")
        with open(call_tree_path, "w") as f:
            f.write(_html_opening("0 Call Tree", finished=finished))
            f.write(f"<h3>{self.name}</h3>")
            f.write("\n")
            for page in self.pages:
                if page.show_in_call_tree:
                    f.write(page.line_text + "\n")
            f.write("\n")
            f.write(_html_closing())

    def enter_function(self) -> Optional["Page"]:
        """
        Adds a new page corresponding to the current function call.
        The page is named after the caller, found by inspecting the calling frame, and is
        pushed onto the page stack until the matching leave_function() call.
        """
        if self.level > self.levels["INFO"]:
            return None

        page = None
        # The frame lookup must stay directly in this method: f_back has to be the caller's frame.
        frame_type = inspect.currentframe()
        if frame_type is not None:
            frame = frame_type.f_back  # Get the calling frame
            if frame is not None:
                # Check if it's a method by looking for 'self' or 'cls' in f_locals
                if "self" in frame.f_locals:
                    class_name = type(frame.f_locals["self"]).__name__
                elif "cls" in frame.f_locals:
                    class_name = frame.f_locals["cls"].__name__
                else:
                    class_name = None  # Not part of a class

                if class_name is None:  # Not part of a class
                    caller_name = frame.f_code.co_name
                else:
                    caller_name = class_name + "." + frame.f_code.co_name

                # Create a new page for this function.
                page = self._add_page(summary=caller_name, show_in_call_tree=True, finished=False)
                self.page_stack.push(page)
                self.page_stack.write_stack_to_page(page)

                page.add_lines("\nENTER {}".format(caller_name), flush=True)
        return page

    def leave_function(self) -> None:
        """
        Finishes the page corresponding to the current function call.
        Does nothing if the page stack is empty.
        """
        if self.level > self.levels["INFO"]:
            return None
        page = self.page_stack.top()
        if page is not None:
            page.finished = True
            page.add_lines("\nLEAVE {}".format(page.summary), flush=True)
            self.page_stack.pop()
class Page:
    """
    A single HTML page of the logger's output.

    Args:
        page_logger: The PageLogger object that created this page.
        index: The index of the page, which also determines its file name.
        summary: A brief summary of the page's contents for display.
        indent_level: The level of indentation in the call tree.
        show_in_call_tree: Whether to display the page in the call tree.
        finished: Whether the page is complete.
    """

    def __init__(
        self,
        page_logger: PageLogger,
        index: int,
        summary: str,
        indent_level: int,
        show_in_call_tree: bool = True,
        finished: bool = True,
    ):
        """
        Stores the page's settings, derives its display strings, and writes the (still empty) page to disk.
        """
        self.page_logger = page_logger
        self.index_str = str(index)
        self.summary = summary
        self.indent_level = indent_level
        self.show_in_call_tree = show_in_call_tree
        self.finished = finished

        # Strings derived from the settings above.
        self.file_title = " ".join((self.index_str, self.summary))
        self.indentation_text = "|&emsp;" * self.indent_level
        self.full_link = f'<a href="{self.index_str}.html">{self.file_title}</a>'
        self.line_text = self.indentation_text + self.full_link

        self.lines: List[str] = []
        self.flush()

    def add_lines(self, lines: str, flush: bool = False) -> None:
        """
        Appends the given text to the page, one entry per newline-separated line,
        and optionally writes the page to disk.
        """
        # Splitting text that has no newline yields the text itself as the only entry.
        self.lines.extend(lines.split("\n"))
        if flush:
            self.flush()

    def flush(self) -> None:
        """
        Rewrites the page's HTML file from scratch with all of the lines added so far.
        """
        html_name = self.index_str + ".html"
        target = os.path.join(self.page_logger.log_dir, html_name)
        with open(target, "w") as out:
            out.write(_html_opening(self.file_title, finished=self.finished))
            out.write(f"<h3>{self.file_title}</h3>\n")
            for text in self.lines:
                try:
                    out.write(f"{text}\n")
                except UnicodeEncodeError:
                    # Replace a line that the file's encoding can't represent with a placeholder.
                    out.write("UnicodeEncodeError in this line.\n")
            out.write(_html_closing())
            out.flush()
class PageStack:
    """
    The stack of pages for the currently active functions, ordered from outermost caller to innermost.
    """

    def __init__(self) -> None:
        self.stack: List[Page] = []

    def push(self, page: Page) -> None:
        """Places a page on top of the stack."""
        self.stack.append(page)

    def pop(self) -> Page:
        """Takes the top page off the stack and returns it."""
        top_page = self.stack.pop()
        return top_page

    def size(self) -> int:
        """Reports how many pages are on the stack."""
        return len(self.stack)

    def top(self) -> Page | None:
        """Returns the top page without removing it, or None if the stack is empty."""
        return self.stack[-1] if self.size() > 0 else None

    def write_stack_to_page(self, page: Page) -> None:
        # Writes the current call stack to the page, one line per stack entry (outermost first),
        # followed by two empty lines, then flushes the page to disk.
        page.add_lines("\nCALL STACK")
        for entry in self.stack:
            page.add_lines(entry.line_text)
        for _ in range(2):
            page.add_lines("")
        page.flush()

View File

@@ -0,0 +1,133 @@
from typing import TYPE_CHECKING, Any
from agentdhal_core import CancellationToken, Image
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from agentdhal_core.model_context import ChatCompletionContext
from agentdhal_core.models import UserMessage
if TYPE_CHECKING:
from agentdhal_extensions.experimental.task_centric_memory import MemoryController
class Teachability(Memory):
    """
    Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.

    Steps for usage:

    1. Instantiate MemoryController.
    2. Instantiate Teachability, passing the memory controller as a parameter.
    3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
    4. Use the AssistantAgent as usual, such as for chatting with the user.

    Args:
        memory_controller: The controller used to store and retrieve memos. Its logger is reused here.
        name: An optional identifier for this memory instance. Defaults to "teachability".
    """

    def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:
        """Initialize Teachability."""
        self._memory_controller = memory_controller
        self._logger = memory_controller.logger
        self._name = name or "teachability"

    @property
    def name(self) -> str:
        """Get the memory instance identifier."""
        return self._name

    def _extract_text(self, content_item: str | MemoryContent) -> str:
        """
        Extract searchable text from content.

        Raises:
            ValueError: If the content is JSON that is not a dict, is an image,
                or has any other unsupported MIME type.
        """
        if isinstance(content_item, str):
            return content_item

        content = content_item.content
        mime_type = content_item.mime_type

        if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
            return str(content)
        elif mime_type == MemoryMimeType.JSON:
            if isinstance(content, dict):
                # Use the lowercased str() form of the dict as the text.
                return str(content).lower()
            raise ValueError("JSON content must be a dict")
        elif isinstance(content, Image):
            raise ValueError("Image content cannot be converted to text")
        else:
            raise ValueError(f"Unsupported content type: {mime_type}")

    async def update_context(
        self,
        model_context: ChatCompletionContext,
    ) -> UpdateContextResult:
        """
        Adds any relevant memories to the model context,
        then extracts any advice from the last message to be stored in memory.

        Returns:
            The memories that were retrieved for the last message (empty if the context has no messages).
        """
        self._logger.enter_function()

        # Extract text from the user's last message
        messages = await model_context.get_messages()
        if not messages:
            self._logger.leave_function()
            return UpdateContextResult(memories=MemoryQueryResult(results=[]))
        last_message = messages[-1]
        last_user_text = last_message.content if isinstance(last_message.content, str) else str(last_message)

        # Add any relevant memories to the chat history
        query_results = await self.query(last_user_text)
        if query_results.results:
            memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
            memory_context = "\nPotentially relevant memories:\n" + "\n".join(memory_strings)
            await model_context.add_message(UserMessage(content=memory_context, source="user"))

        # Add any user advice to memory
        await self._memory_controller.consider_memo_storage(last_user_text)

        self._logger.leave_function()
        return UpdateContextResult(memories=query_results)

    async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
        """
        Tries to extract any advice from the passed content and add it to memory.
        The cancellation token is accepted but not used.
        """
        self._logger.enter_function()

        # Extract text from the incoming content
        text = self._extract_text(content)

        # Check for advice to add to memory for later turns.
        await self._memory_controller.consider_memo_storage(text)

        self._logger.leave_function()

    async def query(
        self,
        query: str | MemoryContent,
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> MemoryQueryResult:
        """
        Returns any memories that seem relevant to the query.
        Each retrieved insight is returned as a plain-text MemoryContent.
        The cancellation token and extra keyword arguments are accepted but not used.
        """
        self._logger.enter_function()

        task = self._extract_text(query)
        memory_results: list[MemoryContent] = []
        filtered_memos = await self._memory_controller.retrieve_relevant_memos(task=task)
        filtered_insights = [memo.insight for memo in filtered_memos]
        for insight in filtered_insights:
            self._logger.info(f"Insight: {insight}")
            memory_content = MemoryContent(
                content=insight,
                # Use the enum member, not its name as a string: the string "MemoryMimeType.TEXT"
                # is not a MIME type, and _extract_text() rejects content tagged with it.
                mime_type=MemoryMimeType.TEXT,
                metadata={},
            )
            memory_results.append(memory_content)

        self._logger.leave_function()
        return MemoryQueryResult(results=memory_results)

    async def clear(self) -> None:
        """Clear all entries from memory."""
        self._memory_controller.reset_memory()

    async def close(self) -> None:
        """Clean up memory resources."""
        pass  # No cleanup needed for this memory implementation