first commit
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
# Task-Centric Memory
|
||||
_(EXPERIMENTAL, RESEARCH IN PROGRESS)_
|
||||
|
||||
**Task-Centric Memory** is an active research project aimed at giving AI agents the ability to:
|
||||
|
||||
* Accomplish general tasks more effectively by learning quickly and continually beyond context-window limitations.
|
||||
* Remember guidance, corrections, plans, and demonstrations provided by users.
|
||||
* Learn through the agent's own experience and adapt quickly to changing circumstances.
|
||||
* Avoid repeating mistakes on tasks that are similar to those previously encountered.
|
||||
|
||||
## Installation
|
||||
|
||||
Install AutoGen and its extension package as follows:
|
||||
|
||||
```bash
|
||||
pip install -U "autogen-agentchat" "autogen-ext[openai]" "autogen-ext[task-centric-memory]"
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory_2.png" alt="Diagram illustrating the basic memory test in the first code snippet" width="150" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
This first code snippet runs a basic test to verify that the installation was successful,
|
||||
as illustrated by the diagram to the right.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.experimental.task_centric_memory import MemoryController
|
||||
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"}) # Optional, but very useful.
|
||||
memory_controller = MemoryController(reset=True, client=client, logger=logger)
|
||||
|
||||
# Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
|
||||
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
|
||||
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
|
||||
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
|
||||
|
||||
# Retrieve memories for a new task that's related to only two of the stored memories.
|
||||
memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
|
||||
print("{} memories retrieved".format(len(memos)))
|
||||
for memo in memos:
|
||||
print("- " + memo.insight)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory_3.png" alt="Description" width="150" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
This second code example shows one way to incorporate task-centric memory directly into an AutoGen agent,
|
||||
in this case a subclass of RoutedAgent.
|
||||
To keep the code short, only the simplest form of memory retrieval is exercised by this agent.
|
||||
|
||||
```python
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler
|
||||
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.experimental.task_centric_memory import MemoryController
|
||||
from autogen_ext.experimental.task_centric_memory.utils import PageLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
content: str
|
||||
|
||||
|
||||
class MemoryEnabledAgent(RoutedAgent):
|
||||
def __init__(
|
||||
self, description: str, model_client: ChatCompletionClient, memory_controller: MemoryController
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._model_client = model_client
|
||||
self._memory_controller = memory_controller
|
||||
|
||||
@message_handler
|
||||
async def handle_message(self, message: Message, context: MessageContext) -> Message:
|
||||
# Retrieve relevant memories for the task.
|
||||
memos = await self._memory_controller.retrieve_relevant_memos(task=message.content)
|
||||
|
||||
# Format the memories for the model.
|
||||
formatted_memos = "Info that may be useful:\n" + "\n".join(["- " + memo.insight for memo in memos])
|
||||
print(f"{'-' * 23}Text appended to the user message{'-' * 24}\n{formatted_memos}\n{'-' * 80}")
|
||||
|
||||
# Create the messages for the model with the retrieved memories.
|
||||
messages: List[LLMMessage] = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content=message.content, source="user"),
|
||||
UserMessage(content=formatted_memos, source="user"),
|
||||
]
|
||||
|
||||
# Call the model with the messages.
|
||||
model_result = await self._model_client.create(messages=messages)
|
||||
assert isinstance(model_result.content, str)
|
||||
|
||||
# Send the model's response to the user.
|
||||
return Message(content=model_result.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart2"}) # Optional, but very useful.
|
||||
memory_controller = MemoryController(reset=True, client=client, logger=logger)
|
||||
|
||||
# Prepopulate memory to mimic learning from a prior session.
|
||||
await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
|
||||
await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
|
||||
await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")
|
||||
|
||||
# Create and start an agent runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
|
||||
# Register the agent type.
|
||||
await MemoryEnabledAgent.register(
|
||||
runtime,
|
||||
"memory_enabled_agent",
|
||||
lambda: MemoryEnabledAgent(
|
||||
"An agent with memory", model_client=client, memory_controller=memory_controller
|
||||
),
|
||||
)
|
||||
|
||||
# Send a direct message to the agent.
|
||||
request = "What colors do I like most?"
|
||||
print("User request: " + request)
|
||||
response = await runtime.send_message(
|
||||
Message(content=request), AgentId("memory_enabled_agent", "default")
|
||||
)
|
||||
print("Agent response: " + response.content)
|
||||
|
||||
# Stop the agent runtime.
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Sample Code
|
||||
|
||||
The example above modifies the agent's code.
|
||||
But it's also possible to add task-centric memory to an agent or multi-agent team _without_ modifying any agent code.
|
||||
See the [sample code](../../../../../../samples/task_centric_memory) for that and other forms of fast, memory-based learning.
|
||||
|
||||
|
||||
## Architecture
|
||||
|
||||
<p align="right">
|
||||
<img src="../../../../imgs/task_centric_memory.png" alt="Block diagram of the task-centric memory architecture" width="300" align="right" style="margin-left: 10px;">
|
||||
</p>
|
||||
|
||||
The block diagram to the right outlines the key components of the architecture in the most general form.
|
||||
The memory components are shown in blue, and the green blocks represent external components.
|
||||
|
||||
The **Memory Controller** implements the fast-learning methods described below,
|
||||
and manages communication with a **Memory Bank** containing a vector DB and associated structures.
|
||||
|
||||
The **Agent or Team** is the AI agent or team of agents to which memory is being added.
|
||||
The sample code shows how to add task-centric memory to a simple AssistantAgent or a MagenticOneGroupChat team.
|
||||
|
||||
The **Apprentice, app, or service** represents the code that instantiates the agent and memory controller,
|
||||
and routes information between them, effectively wrapping agent and memory into a combined component.
|
||||
The term _Apprentice_ connotes that this combination uses memory to learn quickly on the job.
|
||||
The Apprentice class is a minimal reference implementation provided as utility code for illustration and testing,
|
||||
but most applications will use their own code instead of the Apprentice.
|
||||
|
||||
## Memory Creation and Storage
|
||||
|
||||
Each stored memory (called a _memo_) contains a text insight and (optionally) a task description.
|
||||
The insight is intended to help the agent accomplish future tasks that are similar to a prior task.
|
||||
The memory controller provides methods for different types of learning.
|
||||
If the user provides advice for solving a given task, the advice is extracted by the model client and stored as an insight.
|
||||
If the user demonstrates how to perform a task,
|
||||
the task and demonstration are stored together as an insight used to solve similar but different tasks.
|
||||
If the agent is given a task (free of side-effects) and some means of determining success or failure,
|
||||
the memory controller repeats the following learning loop in the background some number of times:
|
||||
|
||||
1. Test the agent on the task a few times to check for a failure.
|
||||
2. If a failure is found, analyze the agent's response in order to:
|
||||
1. Diagnose the failure of reasoning or missing information,
|
||||
2. Phrase a general piece of advice, such as what a teacher might give to a student,
|
||||
3. Temporarily append this advice to the task description,
|
||||
4. Return to step 1.
|
||||
5. If some piece of advice succeeds in helping the agent solve the task a number of times, add the advice as an insight to memory.
|
||||
3. For each insight to be stored in memory, an LLM is prompted to generate a set of free-form, multi-word topics related to the insight. Each topic is embedded to a fixed-length vector and stored in a vector DB mapping it to the topic’s related insight.
|
||||
|
||||
## Memory Retrieval and Usage
|
||||
|
||||
The memory controller provides methods for different types of memory retrieval.
|
||||
When the agent is given a task, the following steps are performed by the controller:
|
||||
1. The task is rephrased into a generalized form.
|
||||
2. A set of free-form, multi-word query topics is generated from the generalized task.
|
||||
3. A potentially large number of previously stored topics, those most similar to each query topic, are retrieved from the vector DB along with the insights they map to.
|
||||
4. These candidate memos are filtered by the aggregate similarity of their stored topics to the query topics.
|
||||
5. In the final filtering stage, an LLM is prompted to validate only those insights that seem potentially useful in solving the task at hand.
|
||||
|
||||
Retrieved insights that pass the filtering steps are listed under a heading like
|
||||
"Important insights that may help solve tasks like this", then appended to the task description before it is passed to the agent as usual.
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._memory_bank import MemoryBankConfig
|
||||
from .memory_controller import MemoryController, MemoryControllerConfig
|
||||
|
||||
__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"]
|
||||
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict
|
||||
|
||||
from ._string_similarity_map import StringSimilarityMap
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
@dataclass
class Memo:
    """
    A single memory record that a memory bank can store and hand back on retrieval.
    """

    task: Optional[str]  # Description of the related task; None when no task was supplied.
    insight: str  # Free-form text (hint, solution, plan, ...) that may help solve a similar task.
|
||||
|
||||
|
||||
class MemoryBankConfig(TypedDict, total=False):
    """
    Optional overrides for the MemoryBank settings; any key may be omitted.

    Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
    the settings that change frequently, as when loading many settings from a single YAML file.
    """

    path: str  # Directory where the memory bank files are stored.
    relevance_conversion_threshold: float  # Threshold used to normalize relevance.
    n_results: int  # Maximum number of most relevant results to return for any given topic.
    distance_threshold: int  # Maximum string-pair distance for a memo to be retrieved.
|
||||
|
||||
|
||||
class MemoryBank:
    """
    Stores task-completion insights as memories in a vector DB for later retrieval.

    Memos are kept in ``uid_memo_dict`` under string IDs and pickled to disk, while a
    StringSimilarityMap maps topic strings to those memo IDs.

    Args:
        reset: True to clear the DB before starting.
        config: An optional dict that can be used to override the following values:

            - path: The path to the directory where the memory bank files are stored.
            - relevance_conversion_threshold: The threshold used to normalize relevance.
            - n_results: The maximum number of most relevant results to return for any given topic.
            - distance_threshold: The maximum string-pair distance for a memo to be retrieved.

        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(
        self,
        reset: bool,
        config: MemoryBankConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.logger.enter_function()

        # Apply default settings and any config overrides.
        memory_dir_path = "./memory_bank/default"
        self.relevance_conversion_threshold = 1.7
        self.n_results = 25
        self.distance_threshold = 100
        if config is not None:
            memory_dir_path = config.get("path", memory_dir_path)
            self.relevance_conversion_threshold = config.get(
                "relevance_conversion_threshold", self.relevance_conversion_threshold
            )
            self.n_results = config.get("n_results", self.n_results)
            self.distance_threshold = config.get("distance_threshold", self.distance_threshold)

        memory_dir_path = os.path.expanduser(memory_dir_path)
        self.logger.info("\nMEMORY BANK DIRECTORY {}".format(memory_dir_path))
        path_to_db_dir = os.path.join(memory_dir_path, "string_map")
        self.path_to_dict = os.path.join(memory_dir_path, "uid_memo_dict.pkl")

        self.string_map = StringSimilarityMap(reset=reset, path_to_db_dir=path_to_db_dir, logger=self.logger)

        # Load or create the associated memo dict on disk.
        self.uid_memo_dict: Dict[str, Memo] = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict))
            # NOTE(security): pickle.load can execute arbitrary code, so the memory bank
            # directory must only ever contain files written by trusted code.
            with open(self.path_to_dict, "rb") as f:
                self.uid_memo_dict = pickle.load(f)
            # Resume numbering above the highest stored ID. Using len() alone could hand out
            # an ID that is already in use whenever the stored IDs do not run from 1 to len().
            self.last_memo_id = self._highest_memo_id(self.uid_memo_dict)
            self.logger.info("\n{} MEMOS LOADED".format(len(self.uid_memo_dict)))

        # Clear the DB if requested.
        if reset:
            self._reset_memos()

        self.logger.leave_function()

    @staticmethod
    def _highest_memo_id(uid_memo_dict: Dict[str, Memo]) -> int:
        """
        Returns the value the memo-ID counter must resume from so that new IDs never collide
        with stored ones: the largest numeric ID, but never less than the number of memos.
        """
        numeric_ids = [int(uid) for uid in uid_memo_dict if str(uid).isdigit()]
        return max([len(uid_memo_dict), *numeric_ids])

    def _next_memo_id(self) -> str:
        """
        Advances the memo-ID counter and returns the new ID as a string.
        """
        self.last_memo_id += 1
        return str(self.last_memo_id)

    def reset(self) -> None:
        """
        Forces immediate deletion of all contents, in memory and on disk.
        """
        self.string_map.reset_db()
        self._reset_memos()

    def _reset_memos(self) -> None:
        """
        Forces immediate deletion of the memos, in memory and on disk.
        """
        self.logger.info("\nCLEARING MEMOS")
        self.uid_memo_dict = {}
        self.save_memos()

    def save_memos(self) -> None:
        """
        Saves the current memo structures (possibly empty) to disk.
        """
        self.string_map.save_string_pairs()
        with open(self.path_to_dict, "wb") as file:
            self.logger.info("\nSAVING MEMOS TO DISK at {}".format(self.path_to_dict))
            pickle.dump(self.uid_memo_dict, file)

    def contains_memos(self) -> bool:
        """
        Returns True if the memory bank contains any memo.
        """
        return len(self.uid_memo_dict) > 0

    def _map_topics_to_memo(self, topics: List[str], memo_id: str, memo: Memo) -> None:
        """
        Adds a mapping in the vec DB from each topic to the memo, then saves everything to disk.
        """
        self.logger.enter_function()
        self.logger.info("\nINSIGHT\n{}".format(memo.insight))
        for topic in topics:
            self.logger.info("\n TOPIC = {}".format(topic))
            self.string_map.add_input_output_pair(topic, memo_id)
        self.uid_memo_dict[memo_id] = memo
        self.save_memos()
        self.logger.leave_function()

    def add_memo(self, insight_str: str, topics: List[str], task_str: Optional[str] = None) -> None:
        """
        Adds an insight to the memory bank, given topics related to the insight, and optionally the task.
        """
        self.logger.enter_function()
        id_str = self._next_memo_id()
        memo = Memo(insight=insight_str, task=task_str)
        self._map_topics_to_memo(topics, id_str, memo)
        self.logger.leave_function()

    def add_task_with_solution(self, task: str, solution: str, topics: List[str]) -> None:
        """
        Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
        This is useful when the insight is a demonstration of how to solve a given type of task.
        """
        self.logger.enter_function()
        id_str = self._next_memo_id()
        # Combine the task and its solution into a single insight, so they are retrieved together.
        insight_str = "Example task:\n\n{}\n\nExample solution:\n\n{}".format(task, solution)
        memo = Memo(insight=insight_str, task=task)
        self._map_topics_to_memo(topics, id_str, memo)
        self.logger.leave_function()

    def get_relevant_memos(self, topics: List[str]) -> List[Memo]:
        """
        Returns any memos from the memory bank that appear sufficiently relevant to the input topics.

        The result is ordered by aggregate relevance, highest first. A memo's relevance is the sum,
        over all of its topic matches, of (relevance_conversion_threshold - distance); memos whose
        total is negative are left out.
        """
        self.logger.enter_function()

        # Retrieve all topic matches, and gather them into a single list.
        matches: List[Tuple[str, str, float]] = []  # Each match is a tuple: (topic, memo_id, distance)
        for topic in topics:
            matches.extend(self.string_map.get_related_string_pairs(topic, self.n_results, self.distance_threshold))

        # Build a dict of memo-relevance pairs from the matches.
        memo_relevance_dict: Dict[str, float] = {}
        for match in matches:
            memo_id = match[1]
            relevance = self.relevance_conversion_threshold - match[2]
            if memo_id in memo_relevance_dict:
                memo_relevance_dict[memo_id] += relevance
            else:
                memo_relevance_dict[memo_id] = relevance

        # Log the details of all the retrieved memos.
        self.logger.info("\n{} POTENTIALLY RELEVANT MEMOS".format(len(memo_relevance_dict)))
        for memo_id, relevance in memo_relevance_dict.items():
            memo = self.uid_memo_dict[memo_id]
            details = ""
            if memo.task is not None:
                details += "\n TASK: {}\n".format(memo.task)
            details += "\n INSIGHT: {}\n\n RELEVANCE: {:.3f}\n".format(memo.insight, relevance)
            self.logger.info(details)

        # Sort the memo-relevance pairs by relevance, in descending order,
        # and keep only the memos that are sufficiently relevant.
        ranked = sorted(memo_relevance_dict.items(), key=lambda item: item[1], reverse=True)
        memo_list: List[Memo] = [self.uid_memo_dict[memo_id] for memo_id, relevance in ranked if relevance >= 0]

        self.logger.leave_function()
        return memo_list
|
||||
@@ -0,0 +1,289 @@
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from .utils._functions import UserContent
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
class Prompter:
|
||||
"""
|
||||
Centralizes most of the Apprentice prompts sent to the model client.
|
||||
|
||||
Args:
|
||||
client: The client to call the model.
|
||||
logger: An optional logger. If None, no logging will be performed.
|
||||
"""
|
||||
|
||||
def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
|
||||
if logger is None:
|
||||
logger = PageLogger() # Nothing will be logged by this object.
|
||||
self.logger = logger
|
||||
|
||||
self.client = client
|
||||
self.default_system_message_content = "You are a helpful assistant."
|
||||
self.time_spent_in_model_calls = 0.0
|
||||
self.num_model_calls = 0
|
||||
self.start_time = time.time()
|
||||
|
||||
# Create the chat history
|
||||
self._chat_history: List[LLMMessage] = []
|
||||
|
||||
async def call_model(
|
||||
self,
|
||||
summary: str,
|
||||
user_content: UserContent,
|
||||
system_message_content: str | None = None,
|
||||
keep_these_messages: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Calls the model client with the given input and returns the response.
|
||||
"""
|
||||
# Prepare the input message list
|
||||
if system_message_content is None:
|
||||
system_message_content = self.default_system_message_content
|
||||
system_message: LLMMessage
|
||||
if self.client.model_info["family"] == "o1":
|
||||
# No system message allowed, so pass it as the first user message.
|
||||
system_message = UserMessage(content=system_message_content, source="User")
|
||||
else:
|
||||
# System message allowed.
|
||||
system_message = SystemMessage(content=system_message_content)
|
||||
|
||||
user_message = UserMessage(content=user_content, source="User")
|
||||
input_messages = [system_message] + self._chat_history + [user_message]
|
||||
|
||||
# Double check the types of the input messages.
|
||||
for message in input_messages:
|
||||
for part in message.content:
|
||||
assert isinstance(part, str) or isinstance(part, Image), "Invalid message content type: {}".format(
|
||||
type(part)
|
||||
)
|
||||
|
||||
# Call the model
|
||||
start_time = time.time()
|
||||
response = await self.client.create(input_messages)
|
||||
assert isinstance(response, CreateResult)
|
||||
response_string = response.content
|
||||
assert isinstance(response_string, str)
|
||||
response_message = AssistantMessage(content=response_string, source="Assistant")
|
||||
assert isinstance(response_message, AssistantMessage)
|
||||
self.time_spent_in_model_calls += time.time() - start_time
|
||||
self.num_model_calls += 1
|
||||
|
||||
# Log the model call
|
||||
self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)
|
||||
|
||||
# Manage the chat history
|
||||
if keep_these_messages:
|
||||
self._chat_history.append(user_message)
|
||||
self._chat_history.append(response_message)
|
||||
|
||||
# Return the response as a string for now
|
||||
return response_string
|
||||
|
||||
def _clear_history(self) -> None:
|
||||
"""
|
||||
Empties the message list containing the chat history.
|
||||
"""
|
||||
self._chat_history = []
|
||||
|
||||
async def learn_from_failure(
|
||||
self, task_description: str, memory_section: str, final_response: str, expected_answer: str, work_history: str
|
||||
) -> str:
|
||||
"""
|
||||
Tries to create an insight to help avoid the given failure in the future.
|
||||
"""
|
||||
sys_message = """- You are a patient and thorough teacher.
|
||||
- Your job is to review work done by students and help them learn how to do better."""
|
||||
|
||||
user_message: List[Union[str, Image]] = []
|
||||
user_message.append("# A team of students made a mistake on the following task:\n")
|
||||
user_message.extend([task_description])
|
||||
|
||||
if len(memory_section) > 0:
|
||||
user_message.append(memory_section)
|
||||
|
||||
user_message.append("# Here's the expected answer, which would have been correct:\n")
|
||||
user_message.append(expected_answer)
|
||||
|
||||
user_message.append("# Here is the students' answer, which was INCORRECT:\n")
|
||||
user_message.append(final_response)
|
||||
|
||||
user_message.append("# Please review the students' work which follows:\n")
|
||||
user_message.append("**----- START OF STUDENTS' WORK -----**\n\n")
|
||||
user_message.append(work_history)
|
||||
user_message.append("\n**----- END OF STUDENTS' WORK -----**\n\n")
|
||||
|
||||
user_message.append(
|
||||
"# Now carefully review the students' work above, explaining in detail what the students did right and what they did wrong.\n"
|
||||
)
|
||||
|
||||
self._clear_history()
|
||||
await self.call_model(
|
||||
summary="Ask the model to learn from this failure",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
user_message = [
|
||||
"Now put yourself in the mind of the students. What misconception led them to their incorrect answer?"
|
||||
]
|
||||
await self.call_model(
|
||||
summary="Ask the model to state the misconception",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
user_message = [
|
||||
"Please express your key insights in the form of short, general advice that will be given to the students. Just one or two sentences, or they won't bother to read it."
|
||||
]
|
||||
insight = await self.call_model(
|
||||
summary="Ask the model to formulate a concise insight",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
return insight
|
||||
|
||||
async def find_index_topics(self, input_string: str) -> List[str]:
|
||||
"""
|
||||
Returns a list of topics related to the given string.
|
||||
"""
|
||||
sys_message = """You are an expert at semantic analysis."""
|
||||
|
||||
user_message: List[Union[str, Image]] = []
|
||||
user_message.append("""- My job is to create a thorough index for a book called Task Completion, and I need your help.
|
||||
- Every paragraph in the book needs to be indexed by all the topics related to various kinds of tasks and strategies for completing them.
|
||||
- Your job is to read the text below and extract the task-completion topics that are covered.
|
||||
- The number of topics depends on the length and content of the text. But you should list at least one topic, and potentially many more.
|
||||
- Each topic you list should be a meaningful phrase composed of a few words. Don't use whole sentences as topics.
|
||||
- Don't include details that are unrelated to the general nature of the task, or a potential strategy for completing tasks.
|
||||
- List each topic on a separate line, without any extra text like numbering, or bullets, or any other formatting, because we don't want those things in the index of the book.\n\n""")
|
||||
|
||||
user_message.append("# Text to be indexed\n")
|
||||
user_message.append(input_string)
|
||||
|
||||
self._clear_history()
|
||||
topics = await self.call_model(
|
||||
summary="Ask the model to extract topics", system_message_content=sys_message, user_content=user_message
|
||||
)
|
||||
|
||||
# Parse the topics into a list.
|
||||
topic_list: List[str] = []
|
||||
for line in topics.split("\n"):
|
||||
if len(line) > 0:
|
||||
topic_list.append(line)
|
||||
|
||||
return topic_list
|
||||
|
||||
async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
|
||||
"""
|
||||
Attempts to rewrite a task description in a more general form.
|
||||
"""
|
||||
|
||||
sys_message = """You are a helpful and thoughtful assistant."""
|
||||
|
||||
user_message: List[Union[str, Image]] = [
|
||||
"We have been given a task description. Our job is not to complete the task, but merely rephrase the task in simpler, more general terms, if possible. Please reach through the following task description, then explain your understanding of the task in detail, as a single, flat list of all the important points."
|
||||
]
|
||||
user_message.append("\n# Task description")
|
||||
user_message.append(task_description)
|
||||
|
||||
self._clear_history()
|
||||
generalized_task = await self.call_model(
|
||||
summary="Ask the model to rephrase the task in a list of important points",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
if revise:
|
||||
user_message = [
|
||||
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
|
||||
]
|
||||
await self.call_model(
|
||||
summary="Ask the model to identify irrelevant points",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
user_message = [
|
||||
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
|
||||
]
|
||||
generalized_task = await self.call_model(
|
||||
summary="Ask the model to make a final list of general terms",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
return generalized_task
|
||||
|
||||
async def validate_insight(self, insight: str, task_description: str) -> bool:
|
||||
"""
|
||||
Judges whether the insight could help solve the task.
|
||||
"""
|
||||
|
||||
sys_message = """You are a helpful and thoughtful assistant."""
|
||||
|
||||
user_message: List[Union[str, Image]] = [
|
||||
"""We have been given a potential insight that may or may not be useful for solving a given task.
|
||||
- First review the following task.
|
||||
- Then review the insight that follows, and consider whether it might help solve the given task.
|
||||
- Do not attempt to actually solve the task.
|
||||
- Reply with a single character, '1' if the insight may be useful, or '0' if it is not."""
|
||||
]
|
||||
user_message.append("\n# Task description")
|
||||
user_message.append(task_description)
|
||||
user_message.append("\n# Possibly useful insight")
|
||||
user_message.append(insight)
|
||||
self._clear_history()
|
||||
response = await self.call_model(
|
||||
summary="Ask the model to validate the insight",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
return response == "1"
|
||||
|
||||
async def extract_task(self, text: str) -> str | None:
|
||||
"""
|
||||
Returns a task found in the given text, or None if not found.
|
||||
"""
|
||||
sys_message = """You are a helpful and thoughtful assistant."""
|
||||
user_message: List[Union[str, Image]] = [
|
||||
"""Does the following text contain a question or a some task we are being asked to perform?
|
||||
- If so, please reply with the full question or task description, along with any supporting information, but without adding extra commentary or formatting.
|
||||
- If the task is just to remember something, that doesn't count as a task, so don't include it.
|
||||
- If there is no question or task in the text, simply write "None" with no punctuation."""
|
||||
]
|
||||
user_message.append("\n# Text to analyze")
|
||||
user_message.append(text)
|
||||
self._clear_history()
|
||||
response = await self.call_model(
|
||||
summary="Ask the model to extract a task", system_message_content=sys_message, user_content=user_message
|
||||
)
|
||||
return response if response != "None" else None
|
||||
|
||||
async def extract_advice(self, text: str) -> str | None:
    """
    Asks the model whether the given text contains information or advice worth keeping.

    Args:
        text: The text to analyze.

    Returns:
        The model's reply (the copied advice), or None when the model
        replies with exactly "None".
    """
    system_prompt = """You are a helpful and thoughtful assistant."""

    instructions = """Does the following text contain any information or advice that might be useful later?
- If so, please copy the information or advice, adding no extra commentary or formatting.
- If there is no potentially useful information or advice at all, simply write "None" with no punctuation."""

    # The user content is the instructions, a section header, then the raw text.
    user_parts: List[Union[str, Image]] = [instructions, "\n# Text to analyze", text]

    # Each extraction is an independent exchange, so start from a clean history.
    self._clear_history()
    reply = await self.call_model(
        summary="Ask the model to extract advice",
        system_message_content=system_prompt,
        user_content=user_parts,
    )
    return None if reply == "None" else reply
|
||||
@@ -0,0 +1,124 @@
|
||||
import os
|
||||
import pickle
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import (
|
||||
QueryResult,
|
||||
)
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
class StringSimilarityMap:
    """
    Provides storage and similarity-based retrieval of string pairs using a vector database.
    Each DB entry is a pair of strings: an input string and an output string.
    The input string is embedded and used as the retrieval key.
    The output string can be anything, but it's typically used as a dict key.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.

    Args:
        - reset: True to clear the DB immediately after creation.
        - path_to_db_dir: Path to the directory where the DB is stored.
        - logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        chromadb_settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(chromadb_settings)
        self.vec_db = self.db_client.create_collection("string-pairs", get_or_create=True)  # The collection is the DB.

        # Load or create the associated string-pair dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict: Dict[str, Tuple[str, str]] = {}
        self.last_string_pair_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict))
            # NOTE(security): pickle.load executes arbitrary code from the file.
            # Only point path_to_db_dir at directories written by this class.
            with open(self.path_to_dict, "rb") as f:
                self.uid_text_dict = pickle.load(f)
            # Continue numbering after the highest stored id rather than the
            # entry count, so a new id can never collide with an existing one
            # even if the stored ids are not contiguous.
            self.last_string_pair_id = max((int(uid) for uid in self.uid_text_dict), default=0)
            if len(self.uid_text_dict) > 0:
                self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict)))
                self._log_string_pairs()

        # Clear the DB if requested.
        if reset:
            self.reset_db()

    def _log_string_pairs(self) -> None:
        """
        Logs all string pairs currently in the map.
        """
        self.logger.debug("LIST OF STRING PAIRS")
        for uid, text in self.uid_text_dict.items():
            input_text, output_text = text
            self.logger.debug(" ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text))

    def save_string_pairs(self) -> None:
        """
        Saves the string-pair dict (self.uid_text_dict) to disk.
        """
        self.logger.debug("\nSAVING STRING SIMILARITY MAP TO DISK at {}".format(self.path_to_dict))
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self) -> None:
        """
        Forces immediate deletion of the DB's contents, in memory and on disk.

        The collection is dropped and recreated, the string-pair dict is emptied
        and saved, and the id counter restarts from zero.
        """
        self.logger.debug("\nCLEARING STRING-PAIR MAP")
        self.db_client.delete_collection("string-pairs")
        self.vec_db = self.db_client.create_collection("string-pairs")
        self.uid_text_dict = {}
        # Restart id numbering together with the data; otherwise ids assigned
        # after a reset would not match the counter rebuilt on the next load.
        self.last_string_pair_id = 0
        self.save_string_pairs()

    def add_input_output_pair(self, input_text: str, output_text: str) -> None:
        """
        Adds one input-output string pair to the DB.

        The input text is added to the vector collection under a new sequential id,
        and the (input, output) pair is stored in the in-memory dict under the same id.
        The dict is not written to disk here; call save_string_pairs() for that.
        """
        self.last_string_pair_id += 1
        uid = str(self.last_string_pair_id)
        self.vec_db.add(documents=[input_text], ids=[uid])
        self.uid_text_dict[uid] = input_text, output_text
        self.logger.debug(
            "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format(
                self.last_string_pair_id, input_text, output_text
            )
        )
        # self._log_string_pairs()  # For deeper debugging, uncomment to log all string pairs after each addition.

    def get_related_string_pairs(
        self, query_text: str, n_results: int, threshold: Union[int, float]
    ) -> List[Tuple[str, str, float]]:
        """
        Retrieves up to n string pairs that are related to the given query text within the specified distance threshold.

        Returns:
            A list of (input_text, output_text, distance) tuples, in the order returned by the
            vector DB, containing only results whose distance is strictly below the threshold.
        """
        string_pairs_with_distances: List[Tuple[str, str, float]] = []
        # Never request more results than there are stored pairs.
        n_results = min(n_results, len(self.uid_text_dict))
        if n_results > 0:
            results: QueryResult = self.vec_db.query(query_texts=[query_text], n_results=n_results)
            documents = results["documents"]
            distances = results["distances"]
            for i, uid in enumerate(results["ids"][0]):
                input_text = documents[0][i] if documents else ""
                distance = distances[0][i] if distances else 0.0
                if distance < threshold:
                    input_text_2, output_text = self.uid_text_dict[uid]
                    # Internal consistency check between the vector DB and the dict.
                    assert input_text == input_text_2
                    self.logger.debug(
                        "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
                            input_text, output_text, distance
                        )
                    )
                    string_pairs_with_distances.append((input_text, output_text, distance))
        return string_pairs_with_distances
|
||||
@@ -0,0 +1,478 @@
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Tuple, TypedDict
|
||||
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
)
|
||||
|
||||
from ._memory_bank import Memo, MemoryBank
|
||||
from ._prompter import Prompter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._memory_bank import MemoryBankConfig
|
||||
from .utils.grader import Grader
|
||||
from .utils.page_logger import PageLogger
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class MemoryControllerConfig(TypedDict, total=False):
    """Optional overrides for MemoryController settings; every key may be omitted."""

    # Whether to rewrite tasks in more general terms.
    generalize_task: bool
    # Whether to critique then rewrite the generalized task.
    revise_generalized_task: bool
    # Whether to base retrieval directly on tasks, or on topics extracted from tasks.
    generate_topics: bool
    # Whether to apply a final validation stage to retrieved memos.
    validate_memos: bool
    # The maximum number of memos to return from retrieve_relevant_memos().
    max_memos_to_retrieve: int
    # The maximum number of learning iterations to attempt when training on a task.
    max_train_trials: int
    # The total number of attempts made when testing for failure on a task.
    max_test_trials: int
    # A config dict passed to MemoryBank.
    MemoryBank: "MemoryBankConfig"
|
||||
|
||||
|
||||
class MemoryController:
    """
    (EXPERIMENTAL, RESEARCH IN PROGRESS)

    Implements fast, memory-based learning, and manages the flow of information to and from a memory bank.

    Args:
        reset: True to empty the memory bank before starting.
        client: The model client to use internally.
        task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
        config: An optional dict that can be used to override the following values:

            - generalize_task: Whether to rewrite tasks in more general terms.
            - revise_generalized_task: Whether to critique then rewrite the generalized task.
            - generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
            - validate_memos: Whether to apply a final validation stage to retrieved memos.
            - max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
            - max_train_trials: The maximum number of learning iterations to attempt when training on a task.
            - max_test_trials: The total number of attempts made when testing for failure on a task.
            - MemoryBank: A config dict passed to MemoryBank.

        logger: An optional logger. If None, a default logger will be created.

    Example:

        The `task-centric-memory` extra first needs to be installed:

        .. code-block:: bash

            pip install "agentdhal-ext[task-centric-memory]"

        The following code snippet shows how to use this class for the most basic storage and retrieval of memories:

        .. code-block:: python

            import asyncio
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.experimental.task_centric_memory import MemoryController
            from agentdhal_extensions.experimental.task_centric_memory.utils import PageLogger


            async def main() -> None:
                client = OpenAIChatCompletionClient(model="gpt-4o")
                logger = PageLogger(config={"level": "DEBUG", "path": "./pagelogs/quickstart"})  # Optional, but very useful.
                memory_controller = MemoryController(reset=True, client=client, logger=logger)

                # Add a few task-insight pairs as memories, where an insight can be any string that may help solve the task.
                await memory_controller.add_memo(task="What color do I like?", insight="Deep blue is my favorite color")
                await memory_controller.add_memo(task="What's another color I like?", insight="I really like cyan")
                await memory_controller.add_memo(task="What's my favorite food?", insight="Halibut is my favorite")

                # Retrieve memories for a new task that's related to only two of the stored memories.
                memos = await memory_controller.retrieve_relevant_memos(task="What colors do I like most?")
                print("{} memories retrieved".format(len(memos)))
                for memo in memos:
                    print("- " + memo.insight)


            asyncio.run(main())
    """

    def __init__(
        self,
        reset: bool,
        client: ChatCompletionClient,
        task_assignment_callback: Callable[[str], Awaitable[Tuple[str, str]]] | None = None,
        config: MemoryControllerConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            logger = PageLogger({"level": "DEBUG"})
        self.logger = logger
        self.logger.enter_function()

        # Apply default settings and any config overrides.
        self.generalize_task = True
        self.revise_generalized_task = True
        self.generate_topics = True
        self.validate_memos = True
        self.max_memos_to_retrieve = 10
        self.max_train_trials = 10
        self.max_test_trials = 3
        memory_bank_config = None
        if config is not None:
            self.generalize_task = config.get("generalize_task", self.generalize_task)
            self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task)
            self.generate_topics = config.get("generate_topics", self.generate_topics)
            self.validate_memos = config.get("validate_memos", self.validate_memos)
            self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve)
            self.max_train_trials = config.get("max_train_trials", self.max_train_trials)
            self.max_test_trials = config.get("max_test_trials", self.max_test_trials)
            memory_bank_config = config.get("MemoryBank", memory_bank_config)

        self.client = client
        self.task_assignment_callback = task_assignment_callback
        self.prompter = Prompter(client, logger)
        self.memory_bank = MemoryBank(reset=reset, config=memory_bank_config, logger=logger)
        self.grader = Grader(client, logger)
        self.logger.leave_function()

    def reset_memory(self) -> None:
        """
        Empties the memory bank in RAM and on disk.
        """
        self.memory_bank.reset()

    async def train_on_task(self, task: str, expected_answer: str) -> None:
        """
        Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.

        If the training loop yields a successful insight, it is stored as a memo for the task.
        """
        self.logger.enter_function()
        self.logger.info("Iterate on the task, possibly discovering a useful new insight.\n")
        _, insight = await self._iterate_on_task(task, expected_answer)
        if insight is None:
            self.logger.info("No useful insight was discovered.\n")
        else:
            self.logger.info("A new insight was created:\n{}".format(insight))
            await self.add_memo(insight, task)
        self.logger.leave_function()

    async def test_on_task(self, task: str, expected_answer: str, num_trials: int = 1) -> Tuple[str, int, int]:
        """
        Assigns a task to the agent, along with any relevant memos retrieved from memory.

        Args:
            task: The task description.
            expected_answer: The answer that responses are graded against.
            num_trials: The number of independent attempts to make.

        Returns:
            A tuple of (the response from the last trial, the number of correct trials, num_trials).
            With no trials, the response is the empty string.
        """
        self.logger.enter_function()
        assert self.task_assignment_callback is not None
        response = ""
        num_successes = 0

        for trial in range(num_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
            task_plus_insights = task

            # Try to retrieve any relevant memories from the DB.
            filtered_memos = await self.retrieve_relevant_memos(task)
            filtered_insights = [memo.insight for memo in filtered_memos]
            if len(filtered_insights) > 0:
                self.logger.info("Relevant insights were retrieved from memory.\n")
                memory_section = self._format_memory_section(filtered_insights)
                if len(memory_section) > 0:
                    task_plus_insights = task + "\n\n" + memory_section

            # Attempt to solve the task.
            self.logger.info("Try to solve the task.\n")
            response, _ = await self.task_assignment_callback(task_plus_insights)

            # Check if the response is correct.
            response_is_correct, extracted_answer = await self.grader.is_response_correct(
                task, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # Calculate the success rate as a percentage, rounded to the nearest whole number.
        # With no trials there is nothing to divide by, so report 0% instead of raising.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info("\nSuccess rate: {}%\n".format(success_rate))
        self.logger.leave_function()
        return response, num_successes, num_trials

    async def add_memo(self, insight: str, task: None | str = None, index_on_both: bool = True) -> None:
        """
        Adds one insight to the memory bank, using the task (if provided) as context.

        Args:
            insight: The insight text to store.
            task: An optional task that the insight relates to.
            index_on_both: When a task is given, True indexes on the generalized task combined
                with the insight; False indexes on the original task alone.
        """
        self.logger.enter_function()

        generalized_task = ""
        if task is not None:
            self.logger.info("\nGIVEN TASK:")
            self.logger.info(task)
            if self.generalize_task:
                generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
            else:
                generalized_task = task

        self.logger.info("\nGIVEN INSIGHT:")
        self.logger.info(insight)

        # Get a list of topics from the insight and the task (if provided).
        if task is None:
            text_to_index = insight
            self.logger.info("\nTOPICS EXTRACTED FROM INSIGHT:")
        elif index_on_both:
            text_to_index = generalized_task.strip() + "\n(Hint: " + insight + ")"
            self.logger.info("\nTOPICS EXTRACTED FROM TASK AND INSIGHT COMBINED:")
        else:
            text_to_index = task
            self.logger.info("\nTOPICS EXTRACTED FROM TASK:")

        if self.generate_topics:
            topics = await self.prompter.find_index_topics(text_to_index)
        else:
            topics = [text_to_index]
        self.logger.info("\n".join(topics))
        self.logger.info("")

        # Add the insight to the memory bank.
        self.memory_bank.add_memo(insight, topics, task)
        self.logger.leave_function()

    async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
        """
        Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
        This is useful when the task-solution pair is an exemplar of solving a task related to some other task.
        """
        self.logger.enter_function()

        self.logger.info("\nEXAMPLE TASK:")
        self.logger.info(task)

        self.logger.info("\nEXAMPLE SOLUTION:")
        self.logger.info(solution)

        # Get a list of topics from the task.
        if self.generate_topics:
            topics = await self.prompter.find_index_topics(task.strip())
        else:
            topics = [task.strip()]
        self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
        self.logger.info("\n".join(topics))
        self.logger.info("")

        # Add the task and solution (as a combined insight) to the memory bank.
        self.memory_bank.add_task_with_solution(task=task, solution=solution, topics=topics)
        self.logger.leave_function()

    async def retrieve_relevant_memos(self, task: str) -> List[Memo]:
        """
        Retrieves any memos from memory that seem relevant to the task.

        Returns:
            At most max_memos_to_retrieve memos. When validate_memos is set, each memo is
            kept only if the prompter validates its insight against the task.
            An empty list is returned when the memory bank holds no memos.
        """
        self.logger.enter_function()

        if self.memory_bank.contains_memos():
            self.logger.info("\nCURRENT TASK:")
            self.logger.info(task)

            # Get a list of topics from the generalized task.
            if self.generalize_task:
                generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
            else:
                generalized_task = task
            if self.generate_topics:
                task_topics = await self.prompter.find_index_topics(generalized_task)
            else:
                task_topics = [generalized_task]
            self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
            self.logger.info("\n".join(task_topics))
            self.logger.info("")

            # Retrieve relevant memos from the memory bank.
            memo_list = self.memory_bank.get_relevant_memos(topics=task_topics)

            # Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant.
            validated_memos: List[Memo] = []
            for memo in memo_list:
                if len(validated_memos) >= self.max_memos_to_retrieve:
                    break
                if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task):
                    validated_memos.append(memo)

            self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))
            for memo in validated_memos:
                if memo.task is not None:
                    self.logger.info("\n TASK: {}".format(memo.task))
                self.logger.info("\n INSIGHT: {}".format(memo.insight))
        else:
            self.logger.info("\nNO SUFFICIENTLY RELEVANT MEMOS WERE FOUND IN MEMORY")
            validated_memos = []

        self.logger.leave_function()
        return validated_memos

    def _format_memory_section(self, memories: List[str]) -> str:
        """
        Formats a list of memories as a section for appending to a task description.

        Returns the empty string when there are no memories; otherwise a header line
        followed by one "- " bullet line per memory.
        """
        if not memories:
            return ""
        header = "## Important insights that may help solve tasks like this\n"
        return header + "".join("- " + mem + "\n" for mem in memories)

    async def _test_for_failure(
        self, task: str, task_plus_insights: str, expected_answer: str
    ) -> Tuple[bool, str, str]:
        """
        Attempts to solve the given task multiple times to find a failure case to learn from.

        Returns:
            A tuple of (whether a failure was found, the last response, the last work history).
            Testing stops at the first incorrect response, or after max_test_trials attempts.
        """
        self.logger.enter_function()
        self.logger.info("\nTask description, including any insights: {}".format(task_plus_insights))
        self.logger.info("\nExpected answer: {}\n".format(expected_answer))

        assert self.task_assignment_callback is not None
        failure_found = False
        response, work_history = "", ""

        for trial in range(self.max_test_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))

            # Attempt to solve the task.
            self.logger.info("Try to solve the task.")
            response, work_history = await self.task_assignment_callback(task_plus_insights)

            response_is_correct, extracted_answer = await self.grader.is_response_correct(
                task, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
            else:
                self.logger.info("Answer is INCORRECT.\n Stop testing, and return the details of the failure.\n")
                failure_found = True
                break

        self.logger.leave_function()
        return failure_found, response, work_history

    async def _iterate_on_task(self, task: str, expected_answer: str) -> Tuple[str, None | str]:
        """
        Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.

        Returns:
            A tuple of (final_response, successful_insight). final_response is the agent's response
            only when no failure was found on the very first trial, and is otherwise empty.
            successful_insight is the most recently learned insight when a later trial found
            no failure, and is otherwise None.
        """
        self.logger.enter_function()
        self.logger.info("\nTask description: {}".format(task))
        self.logger.info("\nExpected answer: {}\n".format(expected_answer))

        final_response = ""
        old_memos = await self.retrieve_relevant_memos(task)
        old_insights = [memo.insight for memo in old_memos]
        insight = None  # The most recent insight learned from a failure, if any.
        successful_insight = None

        # Loop until success (or timeout) while learning from failures.
        for trial in range(1, self.max_train_trials + 1):
            self.logger.info("\n----- TRAIN TRIAL {} -----\n".format(trial))
            task_plus_insights = task

            # Add the retrieved insights, plus the latest new insight (if any).
            current_insights = old_insights if insight is None else old_insights + [insight]
            memory_section = self._format_memory_section(current_insights)
            if len(memory_section) > 0:
                task_plus_insights += "\n\n" + memory_section

            # Can we find a failure case to learn from?
            failure_found, response, work_history = await self._test_for_failure(
                task, task_plus_insights, expected_answer
            )
            if not failure_found:
                # No. Time to exit the loop.
                self.logger.info("\nResponse is CORRECT.\n Stop looking for insights.\n")
                # Was this the first trial?
                if trial == 1:
                    # Yes. We should return the successful response, and no insight.
                    final_response = response
                else:
                    # No. We learned a successful insight, which should be returned.
                    successful_insight = insight
                break

            # Will we try again?
            if trial == self.max_train_trials:
                # No. We're out of training trials.
                self.logger.info("\nNo more trials will be attempted.\n")
                break

            # Try to learn from this failure.
            self.logger.info("\nResponse is INCORRECT. Try to learn from this failure.\n")
            insight = await self.prompter.learn_from_failure(
                task, memory_section, response, expected_answer, work_history
            )
            self.logger.info("\nInsight: {}\n".format(insight))

        # Return the answer from the last loop.
        self.logger.info("\n{}\n".format(final_response))
        self.logger.leave_function()
        return final_response, successful_insight

    async def _append_any_relevant_memories(self, task: str) -> str:
        """
        Appends any relevant memories to the task description.
        """
        self.logger.enter_function()

        filtered_memos = await self.retrieve_relevant_memos(task)
        filtered_insights = [memo.insight for memo in filtered_memos]
        if len(filtered_insights) > 0:
            self.logger.info("Relevant insights were retrieved from memory.\n")
            memory_section = self._format_memory_section(filtered_insights)
            if len(memory_section) > 0:
                task = task + "\n\n" + memory_section

        self.logger.leave_function()
        return task

    async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
        """
        Assigns a task to some agent through the task_assignment_callback, along with any relevant memories.

        Args:
            task: The task description.
            use_memory: True to append relevant memories to the task before assigning it.
            should_await: Must be True; the call is always awaited.

        Returns:
            The response returned by the task_assignment_callback.
        """
        self.logger.enter_function()

        assert self.task_assignment_callback is not None

        if use_memory:
            task = await self._append_any_relevant_memories(task)

        # Attempt to solve the task.
        self.logger.info("Try to solve the task.\n")
        assert should_await
        response, _ = await self.task_assignment_callback(task)

        self.logger.leave_function()
        return response

    async def consider_memo_storage(self, text: str) -> str | None:
        """
        Tries to extract any advice from the given text and add it to memory.

        Returns:
            The extracted advice, or None if the prompter found none.
        """
        self.logger.enter_function()

        advice = await self.prompter.extract_advice(text)
        self.logger.info("Advice: {}".format(advice))
        if advice is not None:
            await self.add_memo(insight=advice)

        self.logger.leave_function()
        return advice

    async def handle_user_message(self, text: str, should_await: bool = True) -> str:
        """
        Handles a user message by extracting any advice as an insight to be stored in memory, and then calling assign_task().
        """
        self.logger.enter_function()

        # Check for advice.
        advice = await self.consider_memo_storage(text)

        # Assign the task through the task_assignment_callback, using memory only if no advice was just provided.
        response = await self.assign_task(text, use_memory=(advice is None), should_await=should_await)

        self.logger.leave_function()
        return response
|
||||
@@ -0,0 +1,15 @@
|
||||
from .apprentice import Apprentice, ApprenticeConfig
|
||||
from .chat_completion_client_recorder import ChatCompletionClientRecorder
|
||||
from .grader import Grader
|
||||
from .page_logger import PageLogger, PageLoggerConfig
|
||||
from .teachability import Teachability
|
||||
|
||||
__all__ = [
|
||||
"Apprentice",
|
||||
"ChatCompletionClientRecorder",
|
||||
"Grader",
|
||||
"PageLogger",
|
||||
"Teachability",
|
||||
"ApprenticeConfig",
|
||||
"PageLoggerConfig",
|
||||
]
|
||||
@@ -0,0 +1,96 @@
|
||||
import hashlib
|
||||
import os
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from agentdhal_core import FunctionCall, Image
|
||||
from agentdhal_core.models import FunctionExecutionResult
|
||||
|
||||
# Convenience types
|
||||
UserContent = Union[str, List[Union[str, Image]]]
|
||||
AssistantContent = Union[str, List[FunctionCall]]
|
||||
FunctionExecutionContent = List[FunctionExecutionResult]
|
||||
SystemContent = str
|
||||
MessageContent = UserContent | AssistantContent | SystemContent | FunctionExecutionContent
|
||||
|
||||
|
||||
def message_content_to_str(message_content: MessageContent | None) -> str:
    """
    Renders message content as a single string.

    None becomes the empty string and a string is returned unchanged.
    For a list, string items are kept as-is, Image items become "<Image>",
    any other item is converted with str() and right-stripped, and the
    pieces are joined with newlines.

    Raises:
        AssertionError: If the content is not None, a string, or a list.
    """
    if message_content is None:
        return ""
    if isinstance(message_content, str):
        return message_content
    if not isinstance(message_content, list):
        raise AssertionError("Unexpected response type.")

    def _render(part: object) -> str:
        # Strings are checked first, so plain text never reaches the Image test.
        if isinstance(part, str):
            return part
        if isinstance(part, Image):
            return "<Image>"
        return str(part).rstrip()

    return "\n".join(_render(part) for part in message_content)
|
||||
|
||||
|
||||
def text_from_user_content(user_content: UserContent) -> str:
    """
    Returns only the textual parts of the user content.

    A string is returned unchanged. For a list, each string item is
    right-stripped and the results are joined with blank lines; non-string
    items are ignored.

    Raises:
        AssertionError: If the content is neither a string nor a list.
    """
    if isinstance(user_content, str):
        return user_content
    if not isinstance(user_content, list):
        raise AssertionError("Unexpected response type.")
    return "\n\n".join(part.rstrip() for part in user_content if isinstance(part, str))
|
||||
|
||||
|
||||
def single_image_from_user_content(user_content: UserContent) -> Union[Image, None]:
    """
    Returns the one image contained in the user content, if any.

    A string has no image, so None is returned. For a list, the Image item is
    returned, or None when the list holds no Image.

    Raises:
        AssertionError: If the list holds more than one Image, or if the
            content is neither a string nor a list.
    """
    if isinstance(user_content, str):
        return None
    if not isinstance(user_content, list):
        raise AssertionError("Unexpected response type.")

    found = [part for part in user_content if isinstance(part, Image)]
    assert len(found) <= 1, "Only one image is currently allowed in the user content."
    return found[-1] if found else None
|
||||
|
||||
|
||||
def hash_directory(directory: str, hash_algo: str = "sha256") -> Tuple[str, int, int]:
    """
    Computes a hash representing the state of a directory, including its structure and file contents.

    Args:
        directory: Path of the directory to walk.
        hash_algo: Name of the hashlib algorithm to use.

    Returns:
        A tuple of (hex digest, total number of files, total number of sub-directories).
        Directory names, file names and file contents feed the digest; a file that
        cannot be read contributes only its name.
    """
    hash_func = hashlib.new(hash_algo)

    # Also count the number of files and sub-directories
    num_files = 0
    num_subdirs = 0

    for root, dirs, files in sorted(os.walk(directory)):  # Ensure order for consistent hashing
        num_files += len(files)
        num_subdirs += len(dirs)

        for dir_name in sorted(dirs):
            hash_func.update(dir_name.encode())  # Hash directory names

        for file_name in sorted(files):
            hash_func.update(file_name.encode())  # Hash file names
            try:
                with open(os.path.join(root, file_name), "rb") as f:
                    while chunk := f.read(4096):  # Read in chunks
                        hash_func.update(chunk)
            except OSError:
                # Best effort: skip files that vanish or cannot be read.
                # Only I/O failures are tolerated; other errors propagate.
                continue

    return hash_func.hexdigest(), num_files, num_subdirs
|
||||
@@ -0,0 +1,257 @@
|
||||
import random
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_agentchat.base import TaskResult
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from .page_logger import PageLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..memory_controller import MemoryControllerConfig
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class ApprenticeConfig(TypedDict, total=False):
    """Optional overrides for Apprentice settings; every key may be omitted."""

    # The name of the target agent or team for assigning tasks to.
    name_of_agent_or_team: str
    # True to disable prefix caching by prepending random ints to the first message.
    disable_prefix_caching: bool
    # A config dict passed to MemoryController.
    MemoryController: "MemoryControllerConfig"
|
||||
|
||||
|
||||
class Apprentice:
    """
    A thin facade that pairs task-centric memory with a target agent or team.

    Applications can use this class as-is, or treat it as a worked example of how to
    instantiate and drive the MemoryController directly.

    Args:
        client: The client to call the model.
        config: An optional dict that can be used to override the following values:

            - name_of_agent_or_team: The name of the target agent or team for assigning tasks to.
            - disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message.
            - MemoryController: A config dict passed to MemoryController.

        logger: An optional logger. If None, a DEBUG-level PageLogger is created.
    """

    def __init__(
        self,
        client: ChatCompletionClient,
        config: ApprenticeConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        self.logger = PageLogger({"level": "DEBUG"}) if logger is None else logger

        # Resolve each setting: the override if present, otherwise the default.
        overrides: ApprenticeConfig = {} if config is None else config
        self.name_of_agent_or_team = overrides.get("name_of_agent_or_team", "AssistantAgent")
        self.disable_prefix_caching = overrides.get("disable_prefix_caching", False)
        memory_controller_config = overrides.get("MemoryController", None)

        self.client = client
        if self.disable_prefix_caching:
            # Time-seeded generator; only exists when prefix caching is disabled.
            self.rand = random.Random(int(time.time() * 1000))

        # Create the MemoryController, which creates the MemoryBank.
        # Imported here rather than at module level (the module only imports it under TYPE_CHECKING).
        from ..memory_controller import MemoryController

        self.memory_controller = MemoryController(
            reset=True,
            client=self.client,
            task_assignment_callback=self.assign_task_to_agent_or_team,
            config=memory_controller_config,
            logger=self.logger,
        )

    def reset_memory(self) -> None:
        """
        Resets the memory bank by delegating to the memory controller.
        """
        self.memory_controller.reset_memory()

    async def handle_user_message(self, text: str, should_await: bool = True) -> str:
        """
        Forwards a user message to the memory controller and returns the controller's response.
        """
        self.logger.enter_function()
        reply = await self.memory_controller.handle_user_message(text, should_await)
        self.logger.leave_function()
        return reply

    async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
        """
        Stores a task-solution pair in the memory bank, to be retrieved together later as a combined insight.
        Useful when the insight is a demonstration of how to solve a given type of task.
        """
        self.logger.enter_function()
        await self.memory_controller.add_task_solution_pair_to_memory(task, solution)
        self.logger.leave_function()

    async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
        """
        Assigns a task through the memory controller, which may attach relevant insights/memories.
        """
        self.logger.enter_function()
        reply = await self.memory_controller.assign_task(task, use_memory, should_await)
        self.logger.leave_function()
        return reply

    async def train_on_task(self, task: str, expected_answer: str) -> None:
        """
        Asks the memory controller to train on the task, learning from failures by storing insights as memories.
        """
        self.logger.enter_function()
        await self.memory_controller.train_on_task(task, expected_answer)
        self.logger.leave_function()

    async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]:
        """
        Passes the given task to the configured agent or team.

        Returns:
            A (response, work_history) pair.

        Raises:
            AssertionError: If ``name_of_agent_or_team`` is not a supported target.
        """
        self.logger.enter_function()

        # Route by the configured target name.
        handlers = {
            "MagenticOneGroupChat": self._assign_task_to_magentic_one,
            "AssistantAgent": self._assign_task_to_assistant_agent,
        }
        handler = handlers.get(self.name_of_agent_or_team)
        if handler is None:
            raise AssertionError("Invalid base agent")
        response, work_history = await handler(task)

        self.logger.leave_function()
        return response, work_history

    async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[Any, Any]:
        """
        Runs the task on a freshly created AssistantAgent that uses a generic 6-step system prompt.

        Returns:
            The text of the agent's last message, twice: once as the response and once as the work history.
        """
        self.logger.enter_function()
        self.logger.info(task)

        system_message_content = """You are a helpful and thoughtful assistant.
In responding to every user message, you follow the same multi-step process given here:
1. Explain your understanding of the user message in detail, covering all the important points.
2. List as many possible responses as you can think of.
3. Carefully list and weigh the pros and cons (if any) of each possible response.
4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist.
5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory.
6. Finish by providing your final response in the particular format requested by the user."""

        if self.disable_prefix_caching:
            # Prepend a random int to disable prefix caching.
            system_message_content = "({})\n\n".format(self.rand.randint(0, 1000000)) + system_message_content

        # The "o1" family takes no system message, so the prompt goes in as the first user message.
        leading_message: LLMMessage = (
            UserMessage(content=system_message_content, source="User")
            if self.client.model_info["family"] == "o1"
            else SystemMessage(content=system_message_content)
        )

        # This list is only used for the log entry below; the agent is given the raw prompt and task.
        input_messages: List[LLMMessage] = [leading_message, UserMessage(content=task, source="User")]

        assistant_agent = AssistantAgent(
            "assistant_agent",
            self.client,
            system_message=system_message_content,
        )

        # Get the agent's response to the task.
        task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
        all_messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
        response_str = all_messages[-1].to_text()

        # Log the model call
        self.logger.log_model_task(
            summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result
        )
        self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))

        self.logger.leave_function()

        # The response doubles as the work history.
        return response_str, response_str

    async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]:
        """
        Builds a MagenticOneGroupChat team (a general assistant plus a web surfer) and runs the task on it.

        Returns:
            The newline-joined text of every message in the run, twice: as the response and as the work history.
        """
        self.logger.enter_function()
        self.logger.info(task)

        general_agent = AssistantAgent(
            "general_agent",
            self.client,
            description="A general GPT-4o AI assistant capable of performing a variety of tasks.",
        )

        from agentdhal_extensions.agents.web_surfer import MultimodalWebSurfer

        web_surfer = MultimodalWebSurfer(
            name="web_surfer",
            model_client=self.client,
            downloads_folder="logs",
            debug_dir="logs",
            to_save_screenshots=True,
        )

        from agentdhal_agentchat.teams import MagenticOneGroupChat

        team = MagenticOneGroupChat(
            [general_agent, web_surfer],
            model_client=self.client,
            max_turns=20,
        )

        # Get the team's response to the task.
        task_result: TaskResult = await team.run(task=task)

        assert isinstance(task_result, TaskResult)

        # MagenticOne's response is the chat history, which we use here as the work history.
        response_str = "\n".join([message.to_text() for message in task_result.messages])

        self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))

        self.logger.leave_function()
        return response_str, response_str
|
||||
@@ -0,0 +1,227 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union
|
||||
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
ModelCapabilities, # type: ignore
|
||||
ModelInfo,
|
||||
RequestUsage,
|
||||
)
|
||||
from agentdhal_core.tools import Tool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .page_logger import PageLogger
|
||||
|
||||
|
||||
class RecordDict(TypedDict):
    """One recorded client call, as stored in the session JSON file."""

    # Which client method produced the record.
    mode: Literal["create", "create_stream"]

    # The serialized input messages of the call.
    messages: List[Mapping[str, Any]]

    # The serialized response of a create call.
    response: Dict[str, Any]

    # The streamed outputs of a create_stream call.
    stream: List[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ChatCompletionClientRecorder(ChatCompletionClient):
    """
    A chat completion client that supports fast, large-scale tests of code calling LLM clients.

    Two modes are supported:

      1. "record": delegates to the underlying client while also recording the input messages and responses,
         which are saved to disk when finalize() is called.
      2. "replay": loads previously recorded messages and responses from disk, then on each call
         checks that its message matches the recorded message, and returns the recorded response.

    The recorded data is stored as a JSON list of records. Each record is a dictionary with a "mode"
    field (either "create" or "create_stream"), a serialized list of messages, and either a "response" (for
    create calls) or a "stream" (a list of streamed outputs for create_stream calls).

    ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:

      - ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client.
      - ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages.

    Args:
        client: The underlying client that calls are delegated to.
        mode: Either "record" or "replay".
        session_file_path: Path of the JSON session file; a leading "~" is expanded.
        logger: An optional logger. If None, a disabled PageLogger is created.

    Raises:
        ValueError: In replay mode, if the session file cannot be loaded.
    """

    def __init__(
        self,
        client: ChatCompletionClient,
        mode: Literal["record", "replay"],
        session_file_path: str,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            self.logger = PageLogger()  # Disabled by default.
        else:
            self.logger = logger
        self.logger.enter_function()
        self.logger.info("Wrapping the base client in ChatCompletionClientRecorder.")

        self.base_client = client
        self.mode = mode
        self.session_file_path = os.path.expanduser(session_file_path)
        self.records: List[RecordDict] = []
        self._record_index = 0  # Position of the next record to replay.
        self._num_checked_records = 0  # How many records have been matched so far.
        if self.mode == "record":
            # Prepare to record the messages and responses.
            self.logger.info("Recording mode enabled.\nRecording session to: " + self.session_file_path)
        elif self.mode == "replay":
            # Load the previously recorded messages and responses from disk.
            self.logger.info("Replay mode enabled.\nRetrieving session from: " + self.session_file_path)
            try:
                with open(self.session_file_path, "r") as f:
                    self.records = json.load(f)
            except Exception as e:
                error_str = f"\nFailed to load recorded session: '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e

        self.logger.leave_function()

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> CreateResult:
        """
        In record mode, calls the base client and appends the exchange to the records.
        In replay mode, checks the messages against the next record and returns its stored response.

        The mutable defaults are never mutated here; they are only forwarded to the base client.

        Raises:
            ValueError: In replay mode, if the records are exhausted, the next record is not a
                "create" record, or its messages differ from the current ones; also if the mode is unknown.
        """
        current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]
        if self.mode == "record":
            response = await self.base_client.create(
                messages,
                tools=tools,
                json_output=json_output,
                tool_choice=tool_choice,
                extra_create_args=extra_create_args,
                cancellation_token=cancellation_token,
            )

            rec: RecordDict = {
                "mode": "create",
                "messages": current_messages,
                "response": response.model_dump(),
                "stream": [],
            }
            self.records.append(rec)
            return response
        elif self.mode == "replay":
            if self._record_index >= len(self.records):
                error_str = "\nNo more recorded turns to check."
                self.logger.error(error_str)
                raise ValueError(error_str)
            rec = self.records[self._record_index]
            if rec.get("mode") != "create":
                error_str = f"\nRecorded call type mismatch at index {self._record_index}: expected 'create', got '{rec.get('mode')}'."
                self.logger.error(error_str)
                raise ValueError(error_str)
            recorded_messages = rec.get("messages")
            if recorded_messages != current_messages:
                error_str = (
                    "\nCurrent message list doesn't match the recorded message list. See the pagelogs for details."
                )
                assert recorded_messages is not None
                self.logger.log_dict_list(recorded_messages, "recorded message list")
                assert current_messages is not None
                self.logger.log_dict_list(current_messages, "current message list")
                self.logger.error(error_str)
                raise ValueError(error_str)
            # The record matched, so consume it.
            self._record_index += 1
            self._num_checked_records += 1

            data = rec.get("response")
            # Populate a CreateResult from the data.
            assert data is not None
            result = CreateResult(
                content=data.get("content", ""),
                finish_reason=data.get("finish_reason", "stop"),
                usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
                cached=True,
            )
            return result

        else:
            error_str = f"\nUnknown mode: {self.mode}"
            self.logger.error(error_str)
            raise ValueError(error_str)

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Delegates directly to the base client's create_stream, in every mode.

        NOTE(review): this method neither records the stream nor checks it against the session,
        so no "create_stream" record is produced or consumed here, even in replay mode.
        """
        return self.base_client.create_stream(
            messages,
            tools=tools,
            tool_choice=tool_choice,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )

    async def close(self) -> None:
        """Closes the base client."""
        await self.base_client.close()

    def actual_usage(self) -> RequestUsage:
        # Calls base_client.actual_usage() and returns the result.
        return self.base_client.actual_usage()

    def total_usage(self) -> RequestUsage:
        # Calls base_client.total_usage() and returns the result.
        return self.base_client.total_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        # Calls base_client.count_tokens() and returns the result.
        return self.base_client.count_tokens(messages, tools=tools)

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        # Calls base_client.remaining_tokens() and returns the result.
        return self.base_client.remaining_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        # Calls base_client.capabilities and returns the result.
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.base_client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        # Calls base_client.model_info and returns the result.
        return self.base_client.model_info

    def finalize(self) -> None:
        """
        In record mode, saves the accumulated records to disk.
        In replay mode, makes sure all the records were checked.

        Raises:
            ValueError: If the records cannot be written (record mode), or if fewer records
                were checked than the session contains (replay mode).
        """
        self.logger.enter_function()
        if self.mode == "record":
            try:
                # Create the directory if it doesn't exist. A bare filename has an empty
                # dirname, and os.makedirs("") raises, so only create a directory that is named.
                session_dir = os.path.dirname(self.session_file_path)
                if session_dir:
                    os.makedirs(session_dir, exist_ok=True)
                # Write the records to disk.
                with open(self.session_file_path, "w") as f:
                    json.dump(self.records, f, indent=2)
                    self.logger.info("\nRecorded session was saved to: " + self.session_file_path)
            except Exception as e:
                error_str = f"Failed to write records to '{self.session_file_path}': {e}"
                self.logger.error(error_str)
                raise ValueError(error_str) from e
        elif self.mode == "replay":
            if self._num_checked_records < len(self.records):
                error_str = f"\nEarly termination. Only {self._num_checked_records} of the {len(self.records)} recorded turns were checked."
                self.logger.error(error_str)
                raise ValueError(error_str)
            self.logger.info("\nRecorded session was fully replayed and checked.")
        self.logger.leave_function()
|
||||
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from ._functions import UserContent
|
||||
from .page_logger import PageLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .apprentice import Apprentice
|
||||
|
||||
|
||||
class Grader:
    """
    Runs basic tests, and determines task success without limitation to string matches.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.client = client

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def test_apprentice(
        self,
        apprentice: Apprentice,
        task_description: str,
        expected_answer: str,
        num_trials: int,
        use_memory: bool,
        client: ChatCompletionClient,
    ) -> Tuple[int, int]:
        """
        Assigns the task to the apprentice ``num_trials`` times and grades each response.

        Args:
            apprentice: The apprentice whose ``assign_task`` is called on every trial.
            task_description: The task given to the apprentice.
            expected_answer: The answer each response is graded against.
            num_trials: How many times to run the task.
            use_memory: Passed through to ``apprentice.assign_task``.
            client: Not used by this method; kept so existing callers keep working.

        Returns:
            A tuple of (number of successes, number of trials).
        """
        self.logger.enter_function()

        self.logger.info("Testing the apprentice on the given task.\n")

        num_successes = 0

        for trial in range(num_trials):
            self.logger.info("\n----- TRIAL {} -----\n".format(trial + 1))
            self.logger.info("Try to solve the task.\n")
            response = await apprentice.assign_task(task_description, use_memory=use_memory)
            response_is_correct, extracted_answer = await self.is_response_correct(
                task_description, response, expected_answer
            )
            self.logger.info("Extracted answer: {}".format(extracted_answer))
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # With no trials there is nothing to divide by; report 0% rather than raise ZeroDivisionError.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info("\nSuccess rate: {}%\n".format(success_rate))
        self.logger.leave_function()
        return num_successes, num_trials

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: Short description of the call, used only for the log entry.
            user_content: Content of the user message to send.
            system_message_content: System prompt; defaults to "You are a helpful assistant."
            keep_these_messages: True to append this exchange to the chat history.

        Returns:
            The response content as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = "You are a helpful assistant."
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)
        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Call the model.
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")
        assert isinstance(response_message, AssistantMessage)

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def is_response_correct(
        self, task_description: str, response_to_be_graded: str, correct_answer: str
    ) -> Tuple[bool, str]:
        """
        Determines whether the response is equivalent to the task's correct answer.

        Makes two model calls, each starting from an empty chat history: one to extract the
        answer from the response, and one to judge that answer against the correct answer.

        Returns:
            A tuple of (True if the model judged the answer correct, the extracted answer).
        """
        self.logger.enter_function()

        sys_message = """You are a helpful and thoughtful assistant."""

        # Ask the model to extract the answer from the response.
        # NOTE: both prompts below are kept verbatim, typos included. Recorded sessions are
        # compared against the exact message text on replay, so editing them would invalidate recordings.
        user_message: List[Union[str, Image]] = []
        user_message.append("""Your job is to extract a possible answer to the following question from the given text.
- First review the following task.
- Then review the text that follows, which may an answer, plus reasoning that led to the answer.
- Do not attempt to actually solve the task yourself.
- Don't try to judge whether the reasoning steps were correct.
- Simply respond by summarizing the answer described in the text, omitting any other parts of the text.
- If no answer is present can be extracted from the text, simply reply "None".""")
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Text that may contain an answer")
        user_message.append(response_to_be_graded)
        user_message_arg: UserContent = user_message
        self._clear_history()
        extracted_answer = await self.call_model(
            summary="Ask the model to extract the answer",
            system_message_content=sys_message,
            user_content=user_message_arg,
        )
        self.logger.info("Extracted answer: " + extracted_answer)

        # Ask the model to check the answer for correctness.
        user_message = [
            """Your job is to decide whether a given answer to a task is correct or not.
- You will be given the task description and the correct, gold-standard answer, along with the answer to be graded.
- In general, an answer is correct if it is equivalent to the correct answer.
- Specifically, the given answer must contain the important information from the correct answer, and must not in any way contradict the correct answer.
- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, or extra commentary.
- An answer should be considered correct if it omits information that is clearly inferred.
- For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct.
- Respond with a single character: '1' if the answer to be graded is correct", '0' if not."""
        ]
        user_message.append("\n# Task description")
        user_message.append(task_description)
        user_message.append("\n# Correct answer")
        user_message.append(correct_answer)
        user_message.append("\n# Answer to be graded")
        user_message.append(extracted_answer)
        self._clear_history()
        decision = await self.call_model(
            summary="Ask the model to check the answer for correctness",
            system_message_content=sys_message,
            user_content=user_message,
        )
        self.logger.info("Decision: " + decision)

        self.logger.leave_function()
        # Tolerate surrounding whitespace (e.g. a trailing newline) around the one-character verdict.
        return decision.strip() == "1", extracted_answer
|
||||
@@ -0,0 +1,546 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
|
||||
|
||||
from agentdhal_agentchat.base import TaskResult
|
||||
from agentdhal_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||
from agentdhal_core import Image
|
||||
from agentdhal_core.models import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from ._functions import MessageContent, hash_directory
|
||||
|
||||
|
||||
def _html_opening(file_title: str, finished: bool = False) -> str:
|
||||
"""
|
||||
Returns the opening text of a simple HTML file.
|
||||
"""
|
||||
refresh_tag = '<meta http-equiv="refresh" content="2">' if not finished else ""
|
||||
st = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
{refresh_tag}
|
||||
<title>{file_title}</title>
|
||||
<style>
|
||||
body {{font-size: 20px}}
|
||||
body {{white-space: pre-wrap}}
|
||||
</style>
|
||||
</head>
|
||||
<body>"""
|
||||
return st
|
||||
|
||||
|
||||
def _html_closing() -> str:
|
||||
"""
|
||||
Return the closing text of a simple HTML file.
|
||||
"""
|
||||
return """</body></html>"""
|
||||
|
||||
|
||||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class PageLoggerConfig(TypedDict, total=False):
    """Optional overrides read by ``PageLogger.__init__``; every key may be omitted."""

    # The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
    level: str

    # The path to the directory where the log files will be written.
    path: str
|
||||
|
||||
|
||||
class PageLogger:
|
||||
"""
|
||||
Logs text and images to a set of HTML pages, one per function/method, linked to each other in a call tree.
|
||||
|
||||
Args:
|
||||
config: An optional dict that can be used to override the following values:
|
||||
|
||||
- level: The logging level, one of DEBUG, INFO, WARNING, ERROR, CRITICAL, or NONE.
|
||||
- path: The path to the directory where the log files will be written.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PageLoggerConfig | None = None) -> None:
|
||||
self.levels = {
|
||||
"DEBUG": 10,
|
||||
"INFO": 20,
|
||||
"WARNING": 30,
|
||||
"ERROR": 40,
|
||||
"CRITICAL": 50,
|
||||
"NONE": 100,
|
||||
}
|
||||
|
||||
# Apply default settings and any config overrides.
|
||||
level_str = "NONE" # Default to no logging at all.
|
||||
self.log_dir = "./pagelogs/default"
|
||||
if config is not None:
|
||||
level_str = config.get("level", level_str)
|
||||
self.log_dir = config.get("path", self.log_dir)
|
||||
self.level = self.levels[level_str]
|
||||
self.log_dir = os.path.expanduser(self.log_dir)
|
||||
|
||||
# If the logging level is set to NONE or higher, don't log anything.
|
||||
if self.level >= self.levels["NONE"]:
|
||||
return
|
||||
|
||||
self.page_stack = PageStack()
|
||||
self.pages: List[Page] = []
|
||||
self.last_page_id = 0
|
||||
self.name = "0 Call Tree"
|
||||
self._create_run_dir()
|
||||
self.flush()
|
||||
self.finalized = False
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.finalize()
|
||||
|
||||
def finalize(self) -> None:
|
||||
# Writes a hash of the log directory to a file for change detection.
|
||||
if self.level >= self.levels["NONE"]:
|
||||
return
|
||||
|
||||
# Don't finalize the log if it has already been finalized.
|
||||
if self.finalized:
|
||||
return
|
||||
|
||||
# Do nothing if the app is being forced to exit early.
|
||||
if self.page_stack.size() > 0:
|
||||
return
|
||||
|
||||
self.flush(finished=True)
|
||||
|
||||
# Write the hash and other details to a file.
|
||||
hash_str, num_files, num_subdirs = hash_directory(self.log_dir)
|
||||
hash_path = os.path.join(self.log_dir, "hash.txt")
|
||||
with open(hash_path, "w") as f:
|
||||
f.write(hash_str)
|
||||
f.write("\n")
|
||||
f.write("{} files\n".format(num_files))
|
||||
f.write("{} subdirectories\n".format(num_subdirs))
|
||||
|
||||
self.finalized = True
|
||||
|
||||
@staticmethod
|
||||
def _decorate_text(text: str, color: str, weight: str = "bold", demarcate: bool = False) -> str:
|
||||
"""
|
||||
Returns a string of text with HTML styling for weight and color.
|
||||
"""
|
||||
if demarcate:
|
||||
text = f"<<<<< {text} >>>>>"
|
||||
return f'<span style="color: {color}; font-weight: {weight};">{text}</span>'
|
||||
|
||||
@staticmethod
|
||||
def _link_to_image(image_path: str, description: str) -> str:
|
||||
"""
|
||||
Returns an HTML string defining a thumbnail link to an image.
|
||||
"""
|
||||
# To avoid a bug in heml rendering aht displays underscores to the left of thumbnails,
|
||||
# define the following string on a single line.
|
||||
link = f"""<a href="{image_path}"><img src="{image_path}" alt="{description}" style="width: 300px; height: auto;"></a>"""
|
||||
return link
|
||||
|
||||
def _get_next_page_id(self) -> int:
|
||||
"""Returns the next page id and increments the counter."""
|
||||
self.last_page_id += 1
|
||||
return self.last_page_id
|
||||
|
||||
def _create_run_dir(self) -> None:
|
||||
"""Creates a fresh log directory."""
|
||||
if os.path.exists(self.log_dir):
|
||||
shutil.rmtree(self.log_dir)
|
||||
os.makedirs(self.log_dir)
|
||||
|
||||
def _add_page(self, summary: str, show_in_call_tree: bool = True, finished: bool = True) -> "Page":
    """
    Creates a new page, registers it with the logger, and returns it.

    The page is indented by the current depth of the page stack. After the call tree
    is rewritten, a link to the new page is logged at INFO level whenever a calling
    page is active on the stack.
    """
    depth = len(self.page_stack.stack)
    new_page = Page(
        page_logger=self,
        index=self._get_next_page_id(),
        summary=summary,
        indent_level=depth,
        show_in_call_tree=show_in_call_tree,
        finished=finished,
    )
    self.pages.append(new_page)
    self.flush()
    if depth > 0:
        # A caller's page is open: leave a link there pointing at the new page.
        self.info("\n" + new_page.full_link)
    return new_page
|
||||
|
||||
def _log_text(self, text: str) -> None:
|
||||
"""
|
||||
Adds text to the current page.
|
||||
"""
|
||||
page = self.page_stack.top()
|
||||
if page is not None:
|
||||
page.add_lines(text, flush=True)
|
||||
|
||||
def debug(self, line: str) -> None:
    """
    Writes the line to the current page unless the logger's level is above DEBUG.
    """
    if self.level > self.levels["DEBUG"]:
        return
    self._log_text(line)
|
||||
|
||||
def info(self, line: str) -> None:
    """
    Writes the line to the current page unless the logger's level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return
    self._log_text(line)
|
||||
|
||||
def warning(self, line: str) -> None:
    """
    Writes the line to the current page unless the logger's level is above WARNING.
    """
    if self.level > self.levels["WARNING"]:
        return
    self._log_text(line)
|
||||
|
||||
def error(self, line: str) -> None:
    """
    Writes the line to the current page unless the logger's level is above ERROR.
    """
    if self.level > self.levels["ERROR"]:
        return
    self._log_text(line)
|
||||
|
||||
def critical(self, line: str) -> None:
    """
    Writes the line to the current page unless the logger's level is above CRITICAL.
    """
    if self.level > self.levels["CRITICAL"]:
        return
    self._log_text(line)
|
||||
|
||||
def _message_source(self, message: LLMMessage) -> str:
    """
    Returns a colored, demarcated label naming the kind of message.

    Unrecognized message types are labeled UNKNOWN in black.
    """
    # Checked in order; the first matching type decides the label and color.
    known_sources = (
        (SystemMessage, "SYSTEM", "purple"),
        (UserMessage, "USER", "blue"),
        (AssistantMessage, "ASSISTANT", "green"),
        (FunctionExecutionResultMessage, "FUNCTION", "red"),
    )
    for message_type, label, hue in known_sources:
        if isinstance(message, message_type):
            return self._decorate_text(label, hue, demarcate=True)
    return self._decorate_text("UNKNOWN", "black", demarcate=True)
|
||||
|
||||
def _format_message_content(self, message_content: MessageContent) -> str:
|
||||
"""
|
||||
Formats the message content for logging.
|
||||
"""
|
||||
# Start by converting the message content to a list of strings.
|
||||
content_list: List[str] = []
|
||||
content = message_content
|
||||
if isinstance(content, str):
|
||||
content_list.append(content)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
content_list.append(item.rstrip())
|
||||
elif isinstance(item, Image):
|
||||
# Save the image to disk.
|
||||
image_filename = str(self._get_next_page_id()) + " image.jpg"
|
||||
image_path = os.path.join(self.log_dir, image_filename)
|
||||
item.image.save(image_path)
|
||||
# Add a link to the image.
|
||||
content_list.append(self._link_to_image(image_filename, "message_image"))
|
||||
elif isinstance(item, Dict):
|
||||
# Add a dictionary to the log.
|
||||
json_str = json.dumps(item, indent=4)
|
||||
content_list.append(json_str)
|
||||
else:
|
||||
content_list.append(str(item).rstrip())
|
||||
else:
|
||||
content_list.append("<UNKNOWN MESSAGE CONTENT>")
|
||||
|
||||
# Convert the list of strings to a single string containing newline separators.
|
||||
output = ""
|
||||
for item in content_list:
|
||||
output += f"\n{item}\n"
|
||||
return output
|
||||
|
||||
def log_message_content(self, message_content: MessageContent, summary: str) -> None:
    """
    Writes the formatted message content (including any images) to a new page.

    The page is kept out of the call tree. Nothing is logged when the logger's level
    is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None
    content_page = self._add_page(summary=summary, show_in_call_tree=False)
    self.page_stack.write_stack_to_page(content_page)
    formatted = self._format_message_content(message_content=message_content)
    content_page.add_lines(formatted)
    content_page.flush()
|
||||
|
||||
def log_dict_list(self, content: List[Mapping[str, Any]], summary: str) -> None:
    """
    Writes each dict in the list to a new page as indented JSON.

    The page is kept out of the call tree. Nothing is logged when the logger's level
    is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None
    dict_page = self._add_page(summary=summary, show_in_call_tree=False)
    self.page_stack.write_stack_to_page(dict_page)

    for entry in content:
        dict_page.add_lines(json.dumps(entry, indent=4))

    dict_page.flush()
|
||||
|
||||
def _log_model_messages(
|
||||
self, summary: str, input_messages: List[LLMMessage], response_str: str, usage: RequestUsage | None
|
||||
) -> Optional["Page"]:
|
||||
"""
|
||||
Adds a page containing the messages to a model (including any input images) and its response.
|
||||
"""
|
||||
page = self._add_page(summary=summary, show_in_call_tree=False)
|
||||
self.page_stack.write_stack_to_page(page)
|
||||
|
||||
if usage is not None:
|
||||
page.add_lines("{} prompt tokens".format(usage.prompt_tokens))
|
||||
page.add_lines("{} completion tokens".format(usage.completion_tokens))
|
||||
for m in input_messages:
|
||||
page.add_lines("\n" + self._message_source(m))
|
||||
page.add_lines(self._format_message_content(message_content=m.content))
|
||||
page.add_lines("\n" + self._decorate_text("ASSISTANT RESPONSE", "green", demarcate=True))
|
||||
page.add_lines("\n" + response_str + "\n")
|
||||
page.flush()
|
||||
return page
|
||||
|
||||
def log_model_call(
    self, summary: str, input_messages: List[LLMMessage], response: CreateResult
) -> Optional["Page"]:
    """
    Logs the messages sent to a model and the CreateResult response to a new page.

    A response whose content is not a string is logged as "??". Returns the new page,
    or None when the logger's level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None

    content = response.content
    response_str = content if isinstance(content, str) else "??"
    return self._log_model_messages(summary, input_messages, response_str, response.usage)
|
||||
|
||||
def log_model_task(
    self, summary: str, input_messages: List[LLMMessage], task_result: TaskResult
) -> Optional["Page"]:
    """
    Logs the messages sent to a model and the TaskResult response to a new page.

    The response text and token usage are taken from the last message of the task
    result. If that message has no models_usage attribute, zero usage is reported.
    A task result with no messages is logged with "??" as the response and no usage,
    rather than raising IndexError.

    Returns the new page, or None when the logger's level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return None

    messages: Sequence[BaseAgentEvent | BaseChatMessage] = task_result.messages
    if not messages:
        # Nothing to report on: a logging call should not crash the caller.
        return self._log_model_messages(summary, input_messages, "??", None)

    message = messages[-1]
    response_str = message.to_text()
    if not isinstance(response_str, str):
        response_str = "??"

    if hasattr(message, "models_usage"):
        usage: RequestUsage | None = message.models_usage
    else:
        usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

    return self._log_model_messages(summary, input_messages, response_str, usage)
|
||||
|
||||
def log_link_to_local_file(self, file_path: str) -> str:
    """
    Returns an HTML anchor whose target and label are both the base name of file_path.
    """
    _, base_name = os.path.split(file_path)
    return f'<a href="{base_name}">{base_name}</a>'
|
||||
|
||||
def add_link_to_image(self, description: str, source_image_path: str) -> None:
    """
    Copies an image into the log directory and adds a thumbnail link to the current page.

    The description is reduced to its alphanumeric and whitespace characters, then used
    both in the copied file's name (prefixed by the next page id) and as the caption.
    """
    # Keep only alphanumeric characters and whitespace.
    description = "".join(filter(lambda ch: ch.isalnum() or ch.isspace(), description))
    target_image_filename = str(self._get_next_page_id()) + " - " + description

    # Place a copy of the image alongside the pages.
    destination = os.path.join(self.log_dir, target_image_filename)
    shutil.copyfile(source_image_path, destination)

    caption = "\n" + description
    self._log_text(caption)
    self._log_text(self._link_to_image(target_image_filename, description))
|
||||
|
||||
def flush(self, finished: bool = False) -> None:
    """
    Rewrites the call-tree HTML file from the current list of pages.

    The file is named after the logger and lists one line per page that is marked to
    show in the call tree. Nothing is written when the logger's level is above INFO.
    """
    if self.level > self.levels["INFO"]:
        return
    tree_path = os.path.join(self.log_dir, self.name + ".html")
    with open(tree_path, "w") as out:
        out.write(_html_opening("0 Call Tree", finished=finished))
        out.write(f"<h3>{self.name}</h3>")
        out.write("\n")
        # One line per page that belongs in the call tree.
        out.writelines(entry.line_text + "\n" for entry in self.pages if entry.show_in_call_tree)
        out.write("\n")
        out.write(_html_closing())
|
||||
|
||||
def enter_function(self) -> Optional["Page"]:
    """
    Opens a new page for the function that called this method.

    The page is named after the caller: "Class.function" when the calling frame has a
    "self" or "cls" local, otherwise just the function name. The page is pushed onto
    the page stack and starts with the call stack and an ENTER line.

    Returns the new page, or None when the logger's level is above INFO or the calling
    frame cannot be determined.
    """
    if self.level > self.levels["INFO"]:
        return None

    # Frame lookups stay inline here: a helper would add a frame and shift f_back.
    own_frame = inspect.currentframe()
    if own_frame is None:
        return None
    caller_frame = own_frame.f_back
    if caller_frame is None:
        return None

    # A 'self' or 'cls' local marks the caller as a method or classmethod.
    if "self" in caller_frame.f_locals:
        owner_name = type(caller_frame.f_locals["self"]).__name__
    elif "cls" in caller_frame.f_locals:
        owner_name = caller_frame.f_locals["cls"].__name__
    else:
        owner_name = None

    function_name = caller_frame.f_code.co_name
    caller_name = function_name if owner_name is None else owner_name + "." + function_name

    # Open the page for this call and make it the current page.
    page = self._add_page(summary=caller_name, show_in_call_tree=True, finished=False)
    self.page_stack.push(page)
    self.page_stack.write_stack_to_page(page)

    page.add_lines("\nENTER {}".format(caller_name), flush=True)
    return page
|
||||
|
||||
def leave_function(self) -> None:
    """
    Closes the page for the current function call.

    Marks the top page of the stack as finished, appends a LEAVE line to it, and pops
    it. Does nothing when the logger's level is above INFO or the stack is empty.
    """
    if self.level > self.levels["INFO"]:
        return None
    closing_page = self.page_stack.top()
    if closing_page is None:
        return None
    closing_page.finished = True
    closing_page.add_lines("\nLEAVE {}".format(closing_page.summary), flush=True)
    self.page_stack.pop()
|
||||
|
||||
|
||||
class Page:
    """
    A single HTML page written by the page logger.

    Args:
        page_logger: The PageLogger object that created this page.
        index: The page's index, which also names its HTML file.
        summary: A short description of the page's contents for display.
        indent_level: How deeply the page is indented in the call tree.
        show_in_call_tree: Whether the page is listed in the call tree.
        finished: Whether the page is complete.
    """

    def __init__(
        self,
        page_logger: PageLogger,
        index: int,
        summary: str,
        indent_level: int,
        show_in_call_tree: bool = True,
        finished: bool = True,
    ):
        """
        Stores the page's properties, derives its display strings, and writes the empty page to disk.
        """
        self.page_logger = page_logger
        self.index_str = str(index)
        self.summary = summary
        self.indent_level = indent_level
        self.show_in_call_tree = show_in_call_tree
        self.finished = finished

        # Display strings derived from the properties above.
        title = self.index_str + " " + summary
        indentation = "| " * indent_level
        link = f'<a href="{self.index_str}.html">{title}</a>'
        self.file_title = title
        self.indentation_text = indentation
        self.full_link = link
        self.line_text = indentation + link

        self.lines: List[str] = []
        self.flush()

    def add_lines(self, lines: str, flush: bool = False) -> None:
        """
        Appends text to the page, one entry per newline-separated line.

        The page is written to disk immediately when flush is true.
        """
        # str.split returns the whole string as a single entry when it has no newline.
        self.lines.extend(lines.split("\n"))
        if flush:
            self.flush()

    def flush(self) -> None:
        """
        Rewrites the page's HTML file in the logger's directory from the stored lines.
        """
        target = os.path.join(self.page_logger.log_dir, self.index_str + ".html")
        with open(target, "w") as out:
            out.write(_html_opening(self.file_title, finished=self.finished))
            out.write(f"<h3>{self.file_title}</h3>\n")
            for text in self.lines:
                try:
                    out.write(f"{text}\n")
                except UnicodeEncodeError:
                    # Replace an unencodable line with a marker instead of failing the write.
                    out.write("UnicodeEncodeError in this line.\n")
            out.write(_html_closing())
            out.flush()
|
||||
|
||||
|
||||
class PageStack:
    """
    The call stack of pages for functions that are currently active, outermost caller first.
    """

    def __init__(self) -> None:
        self.stack: List[Page] = []

    def push(self, page: Page) -> None:
        """Places a page on top of the stack."""
        self.stack.append(page)

    def pop(self) -> Page:
        """Removes the top page from the stack and returns it."""
        return self.stack.pop()

    def size(self) -> int:
        """Returns how many pages are on the stack."""
        return len(self.stack)

    def top(self) -> Page | None:
        """Returns the top page without removing it, or None if the stack is empty."""
        return self.stack[-1] if self.stack else None

    def write_stack_to_page(self, page: Page) -> None:
        """Writes the current call stack to the page, one indented line per stack entry, then flushes it."""
        page.add_lines("\nCALL STACK")
        for active_page in self.stack:
            page.add_lines(active_page.line_text)
        # Two blank lines separate the stack listing from what follows.
        for _ in range(2):
            page.add_lines("")
        page.flush()
|
||||
@@ -0,0 +1,133 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentdhal_core import CancellationToken, Image
|
||||
from agentdhal_core.memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
|
||||
from agentdhal_core.model_context import ChatCompletionContext
|
||||
from agentdhal_core.models import UserMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentdhal_extensions.experimental.task_centric_memory import MemoryController
|
||||
|
||||
|
||||
class Teachability(Memory):
    """
    Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.

    Steps for usage:

        1. Instantiate MemoryController.
        2. Instantiate Teachability, passing the memory controller as a parameter.
        3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
        4. Use the AssistantAgent as usual, such as for chatting with the user.
    """

    def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:
        """
        Initialize Teachability.

        Args:
            memory_controller: The controller that stores and retrieves memos. Its logger is reused here.
            name: Identifier for this memory instance. Defaults to "teachability".
        """
        self._memory_controller = memory_controller
        self._logger = memory_controller.logger
        self._name = name or "teachability"

    @property
    def name(self) -> str:
        """Get the memory instance identifier."""
        return self._name

    def _extract_text(self, content_item: str | MemoryContent) -> str:
        """
        Extract searchable text from content.

        Strings are returned unchanged. TEXT and MARKDOWN content is converted with str().
        JSON content must be a dict and is returned as its lowercased str() form.

        Raises:
            ValueError: If JSON content is not a dict, the content is an Image,
                or the MIME type is not supported.
        """
        if isinstance(content_item, str):
            return content_item

        content = content_item.content
        mime_type = content_item.mime_type

        if mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]:
            return str(content)
        elif mime_type == MemoryMimeType.JSON:
            if isinstance(content, dict):
                # Note: this is the lowercased Python repr of the dict, not a JSON serialization.
                return str(content).lower()
            raise ValueError("JSON content must be a dict")
        elif isinstance(content, Image):
            raise ValueError("Image content cannot be converted to text")
        else:
            raise ValueError(f"Unsupported content type: {mime_type}")

    async def update_context(
        self,
        model_context: ChatCompletionContext,
    ) -> UpdateContextResult:
        """
        Extracts any advice from the last user turn to be stored in memory,
        and adds any relevant memories to the model context.

        Relevant memories are appended to the context as a single numbered UserMessage.
        Returns the queried memories, or an empty result when the context has no messages.
        """
        self._logger.enter_function()

        # Extract text from the user's last message
        messages = await model_context.get_messages()
        if not messages:
            self._logger.leave_function()
            return UpdateContextResult(memories=MemoryQueryResult(results=[]))
        last_message = messages[-1]
        # Non-string content falls back to the string form of the whole message.
        last_user_text = last_message.content if isinstance(last_message.content, str) else str(last_message)

        # Add any relevant memories to the chat history
        query_results = await self.query(last_user_text)
        if query_results.results:
            memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(query_results.results, 1)]
            memory_context = "\nPotentially relevant memories:\n" + "\n".join(memory_strings)
            await model_context.add_message(UserMessage(content=memory_context, source="user"))

        # Add any user advice to memory
        await self._memory_controller.consider_memo_storage(last_user_text)

        self._logger.leave_function()
        return UpdateContextResult(memories=query_results)

    async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
        """
        Tries to extract any advice from the passed content and add it to memory.

        The cancellation token is accepted but not used.
        """
        self._logger.enter_function()

        # Extract text from the incoming content
        text = self._extract_text(content)

        # Check for advice to add to memory for later turns.
        await self._memory_controller.consider_memo_storage(text)

        self._logger.leave_function()

    async def query(
        self,
        query: str | MemoryContent,
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> MemoryQueryResult:
        """
        Returns any memories that seem relevant to the query.

        Each retrieved insight is wrapped in a MemoryContent of type MemoryMimeType.TEXT.
        The cancellation token and extra keyword arguments are accepted but not used.
        """
        self._logger.enter_function()

        task = self._extract_text(query)
        memory_results: list[MemoryContent] = []
        filtered_memos = await self._memory_controller.retrieve_relevant_memos(task=task)
        filtered_insights = [memo.insight for memo in filtered_memos]
        for insight in filtered_insights:
            self._logger.info(f"Insight: {insight}")
            memory_content = MemoryContent(
                content=insight,
                # Use the enum member, not its name as a string, so that
                # _extract_text accepts the returned memories.
                mime_type=MemoryMimeType.TEXT,
                metadata={},
            )
            memory_results.append(memory_content)

        self._logger.leave_function()
        return MemoryQueryResult(results=memory_results)

    async def clear(self) -> None:
        """Clear all entries from memory."""
        self._memory_controller.reset_memory()

    async def close(self) -> None:
        """Clean up memory resources."""
        pass  # No cleanup needed for this memory implementation
|
||||
Reference in New Issue
Block a user